Split backends into separate files and start implementing streaming + tool support

Alex O'Connell
2025-09-15 22:10:09 -04:00
parent d48ccbc271
commit 53052af641
9 changed files with 1273 additions and 1134 deletions

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import logging
from typing import Final
import homeassistant.components.conversation as ha_conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
@@ -30,8 +29,11 @@ from .const import (
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
)
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, GenericOpenAIResponsesAPIAgent, \
TextGenerationWebuiAgent, LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent
from custom_components.llama_conversation.conversation import LocalLLMAgent
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent, GenericOpenAIResponsesAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent, LlamaCppPythonAPIAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]

View File

@@ -0,0 +1,261 @@
"""Defines the OpenAI API compatible agents"""
from __future__ import annotations
import aiohttp
import asyncio
import datetime
import logging
from typing import List, Dict, Tuple, Optional, Any
from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_REQUEST_TIMEOUT,
CONF_OPENAI_API_KEY,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_GENERIC_OPENAI_PATH,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_GENERIC_OPENAI_PATH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
_LOGGER = logging.getLogger(__name__)
class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
api_host: str
api_key: str
model_name: str
async def _async_load_model(self, entry: ConfigEntry) -> None:
self.api_host = format_url(
hostname=entry.data[CONF_HOST],
port=entry.data[CONF_PORT],
ssl=entry.data[CONF_SSL],
path=""
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
async def _async_generate_with_parameters(self, endpoint: str, additional_params: dict) -> TextGenerationResult:
"""Generate a response using the OpenAI-compatible API"""
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
request_params = {
"model": self.model_name,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
request_params.update(additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
response = None
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers
) as response:
response.raise_for_status()
result = await response.json()
except asyncio.TimeoutError:
return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
_LOGGER.debug(result)
return self._extract_response(result)
def _extract_response(self, response_json: dict) -> TextGenerationResult:
raise NotImplementedError("Subclasses must implement _extract_response()")
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/chat/completions"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
return endpoint, request_params
def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/completions"
request_params["prompt"] = self._format_prompt(conversation)
return endpoint, request_params
def _extract_response(self, response_json: dict) -> TextGenerationResult:
choice = response_json["choices"][0]
if response_json["object"] == "chat.completion":
response_text = choice["message"]["content"]
streamed = False
elif response_json["object"] == "chat.completion.chunk":
response_text = choice["message"]["content"]
streamed = True
else:
response_text = choice["text"]
streamed = False
if not streamed or (streamed and choice["finish_reason"]):
if choice["finish_reason"] in ("length", "content_filter"):
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
return TextGenerationResult(
response=response_text,
stop_reason=choice["finish_reason"],
response_streamed=streamed
)
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
result = await self._async_generate_with_parameters(endpoint, additional_params)
return result
class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible Responses API backend."""
_last_response_id: str | None = None
_last_response_id_time: datetime.datetime | None = None
def _responses_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/responses"
request_params["input"] = conversation[-1]["message"] # last message in the conversation is the user input
# Assign previous_response_id if relevant
if self._last_response_id and self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
# If the last response was generated recently, use it as a context
configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=self.entry.options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
last_conversation_age: datetime.timedelta = datetime.datetime.now() - self._last_response_id_time
_LOGGER.debug(f"Conversation ID age: {last_conversation_age}")
if last_conversation_age < configured_memory_time:
_LOGGER.debug(f"Using previous response ID {self._last_response_id} for context")
request_params["previous_response_id"] = self._last_response_id
else:
_LOGGER.debug(f"Previous response ID {self._last_response_id} is too old, not using it for context")
return endpoint, request_params
def _validate_response_payload(self, response_json: dict) -> bool:
"""
Validate that the payload given matches the expected structure for the Responses API.
API ref: https://platform.openai.com/docs/api-reference/responses/object
Returns True or raises an error
"""
required_response_keys = ["object", "output", "status", "id"]
missing_keys = [key for key in required_response_keys if key not in response_json]
if missing_keys:
raise ValueError(f"Response JSON is missing required keys: {', '.join(missing_keys)}")
if response_json["object"] != "response":
raise ValueError(f"Response JSON object is not 'response', got {response_json['object']}")
if "error" in response_json and response_json["error"] is not None:
error = response_json["error"]
_LOGGER.error(f"Response received error payload.")
if "message" not in error:
raise ValueError("Response JSON error is missing 'message' key")
raise ValueError(f"Response JSON error: {error['message']}")
return True
def _check_response_status(self, response_json: dict) -> None:
"""
Check the status of the response and logs a message if it is not 'completed'.
API ref: https://platform.openai.com/docs/api-reference/responses/object#responses_object-status
"""
if response_json["status"] != "completed":
_LOGGER.warning(f"Response status is not 'completed', got {response_json['status']}. Details: {response_json.get('incomplete_details', 'No details provided')}")
def _extract_response(self, response_json: dict) -> TextGenerationResult:
self._validate_response_payload(response_json)
self._check_response_status(response_json)
outputs = response_json["output"]
if len(outputs) > 1:
_LOGGER.warning("Received multiple outputs from the Responses API, returning the first one.")
output = outputs[0]
if not output["type"] == "message":
raise NotImplementedError(f"Response output type is not 'message', got {output['type']}")
if len(output["content"]) > 1:
_LOGGER.warning("Received multiple content items in the response output, returning the first one.")
content = output["content"][0]
output_type = content["type"]
to_return: str | None = None
if output_type == "refusal":
_LOGGER.info("Received a refusal from the Responses API.")
to_return = content["refusal"]
elif output_type == "output_text":
to_return = content["text"]
else:
raise ValueError(f"Response output content type is not expected, got {output_type}")
# Save the response_id and return the successful response.
response_id = response_json["id"]
self._last_response_id = response_id
self._last_response_id_time = datetime.datetime.now()
# wrap in a TextGenerationResult to match the declared return type
return TextGenerationResult(response=to_return)
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
"""Generate a response using the OpenAI-compatible Responses API"""
endpoint, additional_params = self._responses_params(conversation)
return await self._async_generate_with_parameters(endpoint, additional_params)
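
The three payload shapes handled by _extract_response differ in where the generated text lives: full chat completions put it under choices[0].message.content, streaming chunks put partial text under choices[0].delta.content, and the legacy completions endpoint uses choices[0].text. A minimal, self-contained sketch of that dispatch (the sample payloads below are illustrative, not captured from a real server):

def extract_text(response_json: dict) -> str | None:
    choice = response_json["choices"][0]
    if response_json["object"] == "chat.completion":
        return choice["message"]["content"]
    if response_json["object"] == "chat.completion.chunk":
        # streaming chunks carry partial text under "delta"
        return choice.get("delta", {}).get("content")
    return choice["text"]

full = {"object": "chat.completion",
        "choices": [{"message": {"content": "The lights are on."}, "finish_reason": "stop"}]}
chunk = {"object": "chat.completion.chunk",
         "choices": [{"delta": {"content": "The "}, "finish_reason": None}]}
legacy = {"object": "text_completion",
          "choices": [{"text": "The lights are on.", "finish_reason": "stop"}]}

assert extract_text(full) == "The lights are on."
assert extract_text(chunk) == "The "
assert extract_text(legacy) == "The lights are on."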

View File

@@ -0,0 +1,412 @@
"""Defines the llama cpp agent"""
from __future__ import annotations
import importlib
import logging
import os
import threading
import time
from typing import Any, Callable, Generator, Optional, List, Dict, AsyncIterable
from homeassistant.components import conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import callback
from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation
from custom_components.llama_conversation.const import (
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_DOWNLOADED_MODEL_FILE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_CONTEXT_LENGTH,
CONF_BATCH_SIZE,
CONF_THREAD_COUNT,
CONF_BATCH_THREAD_COUNT,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_BATCH_SIZE,
DEFAULT_THREAD_COUNT,
DEFAULT_BATCH_THREAD_COUNT,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
# make type checking work for llama-cpp-python without importing it directly at runtime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from llama_cpp import Llama as LlamaType
else:
LlamaType = Any
_LOGGER = logging.getLogger(__name__)
class LlamaCppAgent(LocalLLMAgent):
model_path: str
llm: LlamaType
grammar: Any
llama_cpp_module: Any
remove_prompt_caching_listener: Callable
model_lock: threading.Lock
last_cache_prime: float
last_updated_entities: dict[str, float]
cache_refresh_after_cooldown: bool
loaded_model_settings: dict[str, Any]
_attr_supports_streaming = True
def _load_model(self, entry: ConfigEntry) -> None:
self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
_LOGGER.info(
"Using model file '%s'", self.model_path
)
if not self.model_path:
raise Exception(f"Model was not found at '{self.model_path}'!")
validate_llama_cpp_python_installation()
# don't import it until now because the wheel is installed by config_flow.py
try:
self.llama_cpp_module = importlib.import_module("llama_cpp")
except ModuleNotFoundError:
# attempt to re-install llama-cpp-python if it was uninstalled for some reason
install_result = install_llama_cpp_python(self.hass.config.config_dir)
if install_result is not True:
raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")
validate_llama_cpp_python_installation()
self.llama_cpp_module = importlib.import_module("llama_cpp")
Llama = getattr(self.llama_cpp_module, "Llama")
_LOGGER.debug(f"Loading model '{self.model_path}'...")
self.loaded_model_settings = {}
self.loaded_model_settings[CONF_CONTEXT_LENGTH] = entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
self.llm = Llama(
model_path=self.model_path,
n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
)
_LOGGER.debug("Model loaded")
self.grammar = None
if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE))
# TODO: check about disk caching
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
# capacity_bytes=(512 * 10e8),
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
# ))
self.remove_prompt_caching_listener = None
self.last_cache_prime = None
self.last_updated_entities = {}
self.cache_refresh_after_cooldown = False
self.model_lock = threading.Lock()
self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED]:
@callback
async def enable_caching_after_startup(_now) -> None:
self._set_prompt_caching(enabled=True)
await self._async_cache_prompt(None, None, None)
async_call_later(self.hass, 5.0, enable_caching_after_startup)
def _load_grammar(self, filename: str):
LlamaGrammar = getattr(self.llama_cpp_module, "LlamaGrammar")
_LOGGER.debug(f"Loading grammar {filename}...")
try:
with open(os.path.join(os.path.dirname(__file__), filename)) as f:
grammar_str = "".join(f.readlines())
self.grammar = LlamaGrammar.from_string(grammar_str)
self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] = filename
_LOGGER.debug("Loaded grammar")
except Exception:
_LOGGER.exception("Failed to load grammar!")
self.grammar = None
def _update_options(self):
LocalLLMAgent._update_options(self)
model_reloaded = False
if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT) or \
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION):
_LOGGER.debug(f"Reloading model '{self.model_path}'...")
self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
Llama = getattr(self.llama_cpp_module, "Llama")
self.llm = Llama(
model_path=self.model_path,
n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
)
_LOGGER.debug("Model loaded")
model_reloaded = True
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
current_grammar = self.entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
if not self.grammar or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
self._load_grammar(current_grammar)
else:
self.grammar = None
if self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
self._set_prompt_caching(enabled=True)
if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \
model_reloaded:
self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
async def cache_current_prompt(_now):
await self._async_cache_prompt(None, None, None)
async_call_later(self.hass, 1.0, cache_current_prompt)
else:
self._set_prompt_caching(enabled=False)
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
"""Takes the super class function results and sorts the entities with the recently updated at the end"""
entities, domains = LocalLLMAgent._async_get_exposed_entities(self)
# ignore sorting if prompt caching is disabled
if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
return entities, domains
entity_order = { name: None for name in entities.keys() }
entity_order.update(self.last_updated_entities)
def sort_key(item):
item_name, last_updated = item
# Handle cases where last updated is None
if last_updated is None:
return (False, '', item_name)
else:
return (True, last_updated, '')
# Sort the items based on the sort_key function
sorted_items = sorted(list(entity_order.items()), key=sort_key)
_LOGGER.debug(f"sorted_items: {sorted_items}")
sorted_entities = {}
for item_name, _ in sorted_items:
sorted_entities[item_name] = entities[item_name]
return sorted_entities, domains
def _set_prompt_caching(self, *, enabled=True):
if enabled and not self.remove_prompt_caching_listener:
_LOGGER.info("enabling prompt caching...")
entity_ids = [
state.entity_id for state in self.hass.states.async_all() \
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
]
_LOGGER.debug(f"watching entities: {entity_ids}")
self.remove_prompt_caching_listener = async_track_state_change(self.hass, entity_ids, self._async_cache_prompt)
elif not enabled and self.remove_prompt_caching_listener:
_LOGGER.info("disabling prompt caching...")
self.remove_prompt_caching_listener()
@callback
async def _async_cache_prompt(self, entity, old_state, new_state):
refresh_start = time.time()
# track last update time so we can sort the context efficiently
if entity:
self.last_updated_entities[entity] = refresh_start
llm_api: llm.APIInstance | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError:
_LOGGER.exception("Failed to get LLM API when caching prompt!")
return
_LOGGER.debug(f"refreshing cached prompt because {entity} changed...")
await self.hass.async_add_executor_job(self._cache_prompt, llm_api)
refresh_end = time.time()
_LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")
def _cache_prompt(self, llm_api: llm.APIInstance | None) -> None:
# if a refresh is already scheduled then exit
if self.cache_refresh_after_cooldown:
return
# if we are inside the cooldown period, request a refresh and exit
current_time = time.time()
fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval:
self.cache_refresh_after_cooldown = True
return
# try to acquire the lock, if we are still running for some reason, request a refresh and exit
lock_acquired = self.model_lock.acquire(False)
if not lock_acquired:
self.cache_refresh_after_cooldown = True
return
try:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt = self._format_prompt([
{ "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
{ "role": "user", "message": "" }
], include_generation_prompt=False)
input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
_LOGGER.debug(f"Options: {self.entry.options}")
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
# grab just one token. should prime the kv cache with the system prompt
next(self.llm.generate(
input_tokens,
temp=temperature,
top_k=top_k,
top_p=top_p,
grammar=grammar
))
self.last_cache_prime = time.time()
finally:
self.model_lock.release()
# schedule a refresh using async_call_later
# if the flag is set after the delay then we do another refresh
@callback
async def refresh_if_requested(_now):
if self.cache_refresh_after_cooldown:
self.cache_refresh_after_cooldown = False
refresh_start = time.time()
_LOGGER.debug(f"refreshing cached prompt after cooldown...")
await self.hass.async_add_executor_job(self._cache_prompt)
refresh_end = time.time()
_LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")
refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
def _generate_stream(self, conversation: List[Dict[str, str]], user_input: conversation.ConversationInput, chat_log: conversation.ChatLog) -> TextGenerationResult:
prompt = self._format_prompt(conversation)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
min_p = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
_LOGGER.debug(f"Options: {self.entry.options}")
with self.model_lock:
input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
if len(input_tokens) >= context_len:
num_entities = len(self._async_get_exposed_entities()[0])
context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
self._warn_context_size()
raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
if len(input_tokens) + max_tokens >= context_len:
self._warn_context_size()
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
output_tokens = self.llm.generate(
input_tokens,
temp=temperature,
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
grammar=self.grammar
)
result_tokens = []
async def async_iterator():
num_tokens = 0
for token in output_tokens:
result_tokens.append(token)
yield TextGenerationResult(response=self.llm.detokenize([token]).decode(), response_streamed=True)
if token == self.llm.token_eos():
break
if len(result_tokens) >= max_tokens:
break
num_tokens += 1
self._transform_result_stream(async_iterator(), user_input=user_input, chat_log=chat_log)
response = TextGenerationResult(
response=self.llm.detokenize(result_tokens).decode(),
response_streamed=True,
)
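
The reordering in _async_get_exposed_entities exists so that the static part of the system prompt (and the entities that rarely change) stays at the front of the context, keeping the llama.cpp KV cache prefix valid, while recently updated entities are pushed to the end. A small standalone sketch of that sort key (entity names and timestamps are made up for illustration):

import time

entities = ["light.kitchen", "light.office", "sensor.temp", "switch.fan"]
last_updated = {"light.office": time.time(), "sensor.temp": time.time() - 60}

entity_order = {name: None for name in entities}
entity_order.update(last_updated)

def sort_key(item):
    name, updated = item
    if updated is None:
        return (False, '', name)   # never-updated entities first, alphabetical
    return (True, updated, '')     # then by update time, oldest first

print([name for name, _ in sorted(entity_order.items(), key=sort_key)])
# ['light.kitchen', 'switch.fan', 'sensor.temp', 'light.office']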

View File

@@ -0,0 +1,198 @@
"""Defines the ollama compatible agent"""
from __future__ import annotations
import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any
from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_REQUEST_TIMEOUT,
CONF_OPENAI_API_KEY,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
_LOGGER = logging.getLogger(__name__)
class OllamaAPIAgent(LocalLLMAgent):
api_host: str
api_key: Optional[str]
model_name: Optional[str]
async def _async_load_model(self, entry: ConfigEntry) -> None:
self.api_host = format_url(
hostname=entry.data[CONF_HOST],
port=entry.data[CONF_PORT],
ssl=entry.data[CONF_SSL],
path=""
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
# ollama handles loading for us so just make sure the model is available
try:
headers = {}
session = async_get_clientsession(self.hass)
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
async with session.get(
f"{self.api_host}/api/tags",
headers=headers,
) as response:
response.raise_for_status()
currently_downloaded_result = await response.json()
except Exception as ex:
_LOGGER.debug("Connection error was: %s", repr(ex))
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
model_names = [ x["name"] for x in currently_downloaded_result["models"] ]
if ":" in self.model_name:
if not any([ name == self.model_name for name in model_names]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
request_params = {}
endpoint = "/api/chat"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
return endpoint, request_params
def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
endpoint = "/api/generate"
request_params["prompt"] = self._format_prompt(conversation)
request_params["raw"] = True # ignore prompt template
return endpoint, request_params
def _extract_response(self, response_json: Dict) -> TextGenerationResult:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
# if response_json["prompt_eval_count"] + max_tokens > context_len:
# self._warn_context_size()
if "response" in response_json:
response = response_json["response"]
tool_calls = None
stop_reason = None
if response_json["done"] not in ["true", True]:
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
else:
response = response_json["message"]["content"]
tool_calls = response_json["message"].get("tool_calls")
stop_reason = response_json.get("done_reason")
return TextGenerationResult(
response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
)
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
request_params = {
"model": self.model_name,
"stream": True,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
"options": {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_preDict": max_tokens,
}
}
if json_mode:
request_params["format"] = "json"
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
request_params.update(additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
response = None
result = TextGenerationResult(
response="", response_streamed=True
)
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers
) as response:
response.raise_for_status()
while True:
chunk = await response.content.readline()
if not chunk:
break
parsed_chunk = self._extract_response(json.loads(chunk))
result.response += parsed_chunk.response
result.stop_reason = parsed_chunk.stop_reason
result.tool_calls = parsed_chunk.tool_calls
except asyncio.TimeoutError:
return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
_LOGGER.debug(result)
return result
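
The streaming loop above reads Ollama's response as newline-delimited JSON and folds every chunk into a single TextGenerationResult. A sketch of that accumulation with handwritten example chunks (field names follow the /api/chat stream format; the final chunk is the one that carries "done", "done_reason" and any "tool_calls"):

import json

raw_lines = [
    '{"message": {"content": "Turning "}, "done": false}',
    '{"message": {"content": "on the lights."}, "done": false}',
    '{"message": {"content": "", "tool_calls": [{"function": {"name": "HassTurnOn"}}]}, "done": true, "done_reason": "stop"}',
]

response_text, tool_calls, stop_reason = "", None, None
for line in raw_lines:
    chunk = json.loads(line)
    response_text += chunk["message"]["content"]
    tool_calls = chunk["message"].get("tool_calls") or tool_calls
    stop_reason = chunk.get("done_reason") or stop_reason

print(response_text)  # Turning on the lights.
print(tool_calls)     # [{'function': {'name': 'HassTurnOn'}}]
print(stop_reason)    # stop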

View File

@@ -0,0 +1,171 @@
"""Defines the various openai-like agents"""
from __future__ import annotations
import logging
import os
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass
from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.llama_conversation.const import (
CONF_MAX_TOKENS,
CONF_TOP_K,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_PROMPT_TEMPLATE,
CONF_USE_GBNF_GRAMMAR,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_CONTEXT_LENGTH,
DEFAULT_MAX_TOKENS,
DEFAULT_TOP_K,
DEFAULT_MIN_P,
DEFAULT_TYPICAL_P,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_CONTEXT_LENGTH,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
)
from custom_components.llama_conversation.conversation import TextGenerationResult
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
_LOGGER = logging.getLogger(__name__)
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
admin_key: str
async def _async_load_model(self, entry: ConfigEntry) -> None:
await super()._async_load_model(entry)
self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
try:
headers = {}
session = async_get_clientsession(self.hass)
if self.admin_key:
headers["Authorization"] = f"Bearer {self.admin_key}"
async with session.get(
f"{self.api_host}/v1/internal/model/info",
headers=headers
) as response:
response.raise_for_status()
currently_loaded_result = await response.json()
loaded_model = currently_loaded_result["model_name"]
if loaded_model == self.model_name:
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
return
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
async with session.post(
f"{self.api_host}/v1/internal/model/load",
json={
"model_name": self.model_name,
# TODO: expose arguments to the user in home assistant UI
# "args": {},
},
headers=headers
) as response:
response.raise_for_status()
except Exception as ex:
_LOGGER.debug("Connection error was: %s", repr(ex))
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
chat_mode = self.entry.options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE)
endpoint, request_params = super()._chat_completion_params(conversation)
request_params["mode"] = chat_mode
if chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT or chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT:
if preset:
request_params["character"] = preset
elif chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT:
request_params["instruction_template"] = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
return endpoint, request_params
def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
endpoint, request_params = super()._completion_params(conversation)
if preset:
request_params["preset"] = preset
request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
return endpoint, request_params
def _extract_response(self, response_json: dict) -> TextGenerationResult:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
if response_json["usage"]["prompt_tokens"] + max_tokens > context_len:
self._warn_context_size()
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
if response_json["object"] == "chat.completions":
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
"""https://llama-cpp-python.readthedocs.io/en/latest/server/"""
grammar: str
async def _async_load_model(self, entry: ConfigEntry):
await super()._async_load_model(entry)
return await self.hass.async_add_executor_job(
self._load_model, entry
)
def _load_model(self, entry: ConfigEntry):
with open(os.path.join(os.path.dirname(__file__), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
self.grammar = "".join(f.readlines())
def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._chat_completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
return endpoint, request_params
def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
return endpoint, request_params
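
For reference, the fields these tailored agents layer on top of the generic chat-completion body depend on the configured chat mode. A minimal sketch of that branching, assuming the mode constants resolve to "chat", "chat-instruct" and "instruct" and using placeholder option values:

def webui_extra_params(chat_mode: str, preset: str | None, prompt_template: str) -> dict:
    # mirrors the mode handling in TextGenerationWebuiAgent._chat_completion_params
    params = {"mode": chat_mode, "truncation_length": 2048,
              "top_k": 40, "min_p": 0.05, "typical_p": 1.0}
    if chat_mode in ("chat", "chat-instruct"):
        if preset:
            params["character"] = preset
    elif chat_mode == "instruct":
        params["instruction_template"] = prompt_template
    return params

print(webui_extra_params("instruct", None, "chatml"))
print(webui_extra_params("chat", "Some Character", "chatml"))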

File diff suppressed because it is too large

View File

@@ -1,9 +1,9 @@
# types from Home Assistant
homeassistant>=2024.6.1
homeassistant>=2024.7.0
hassil
home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component
pytest-homeassistant-custom-component==0.13.260

View File

@@ -1,5 +1,4 @@
-r requirements.txt
-r requirements-dev.txt
-r custom_components/requirements.txt
-r custom_components/requirements-dev.txt
-r data/requirements.txt

View File

@@ -4,7 +4,10 @@ import pytest
import jinja2
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
from custom_components.llama_conversation.conversation import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -106,7 +109,8 @@ class MockConfigEntry:
def __init__(self, entry_id='test_entry_id', data={}, options={}):
self.entry_id = entry_id
self.data = WarnDict(data)
self.options = WarnDict(options)
# Use a mutable dict for options in tests
self.options = WarnDict(dict(options))
@pytest.fixture
@@ -138,14 +142,16 @@ def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:
entry_mock.return_value = config_entry
llama_instance_mock = MagicMock()
llama_class_mock = MagicMock()
llama_class_mock.return_value = llama_instance_mock
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
install_llama_cpp_python_mock.return_value = True
get_exposed_entities_mock.return_value = (
{
@@ -194,7 +200,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await local_llama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
@@ -249,7 +255,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
# do another turn of the same conversation
result = await local_llama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -272,8 +278,7 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
@@ -288,7 +293,7 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
response_mock = MagicMock()
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
requests_get_mock.return_value = response_mock
get_clientsession.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
@@ -298,8 +303,8 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
)
all_mocks = {
"requests_get": requests_get_mock,
"requests_post": requests_post_mock
"requests_get": get_clientsession.get,
"requests_post": get_clientsession.post
}
yield agent_obj, all_mocks
@@ -339,7 +344,7 @@ async def test_ollama_agent(ollama_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await ollama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
@@ -384,7 +389,7 @@ async def test_ollama_agent(ollama_agent_fixture):
# do another turn of the same conversation
result = await ollama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -418,8 +423,7 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integr
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
@@ -434,7 +438,7 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integr
response_mock = MagicMock()
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
requests_get_mock.return_value = response_mock
get_clientsession_mock.get.return_value = response_mock
call_tool_mock.return_value = {"result": "success"}
@@ -444,8 +448,8 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integr
)
all_mocks = {
"requests_get": requests_get_mock,
"requests_post": requests_post_mock
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
@@ -490,7 +494,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
@@ -521,7 +525,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -588,7 +592,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
# do another turn of the same conversation but the chat endpoint
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -611,69 +615,6 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Character"
# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"character": "Some Character",
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()
text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE] = TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT
# do another turn of the same conversation and use instruct mode
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"instruction_template": "chatml",
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)
@pytest.fixture
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
@@ -682,8 +623,7 @@ def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations)
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:
entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
@@ -704,8 +644,8 @@ def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations)
)
all_mocks = {
"requests_get": requests_get_mock,
"requests_post": requests_post_mock
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}
yield agent_obj, all_mocks
@@ -745,7 +685,7 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await generic_openai_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
# assert on results + check side effects
@@ -804,7 +744,7 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
# do another turn of the same conversation but the chat endpoint
result = await generic_openai_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))
assert result.response.speech['plain']['speech'] == "I am saying something!"
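
Since the fixtures now patch async_get_clientsession instead of requests.get/requests.post, the mocked session has to behave like an async context manager when the backends do "async with session.post(...) as response". One way to build such a mock (a sketch, not necessarily how the suite ends up doing it):

from unittest.mock import AsyncMock, MagicMock

def make_fake_session(payload: dict) -> MagicMock:
    response = MagicMock()
    response.raise_for_status = MagicMock()
    response.json = AsyncMock(return_value=payload)

    session = MagicMock()
    # MagicMock pre-configures __aenter__/__aexit__ (Python 3.8+), so the
    # return value of session.get()/session.post() works with "async with"
    session.get.return_value.__aenter__.return_value = response
    session.get.return_value.__aexit__.return_value = False
    session.post.return_value.__aenter__.return_value = response
    session.post.return_value.__aexit__.return_value = False
    return session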