Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 05:14:02 -05:00)
Split backends into separate files and start implementing streaming + tool support
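With this split, callers import each agent from its own backend module instead of one monolithic module. A natural way to consume that is a backend-type-to-agent-class map; the dict below is only an illustrative sketch (the real dispatch logic in __init__.py is outside this diff):

from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import (
    TextGenerationWebuiAgent,
    LlamaCppPythonAPIAgent,
)
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
from custom_components.llama_conversation.const import (
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
)

# illustrative mapping only; not part of this commit
AGENT_FOR_BACKEND = {
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER: LlamaCppPythonAPIAgent,
    BACKEND_TYPE_OLLAMA: OllamaAPIAgent,
}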
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging
from typing import Final

import homeassistant.components.conversation as ha_conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
@@ -30,8 +29,11 @@ from .const import (
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
)
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, GenericOpenAIResponsesAPIAgent, \
    TextGenerationWebuiAgent, LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent
from custom_components.llama_conversation.conversation import LocalLLMAgent
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent, GenericOpenAIResponsesAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent, LlamaCppPythonAPIAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent

type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]

261
custom_components/llama_conversation/backends/generic_openai.py
Normal file
@@ -0,0 +1,261 @@
"""Defines the OpenAI API compatible agents"""
from __future__ import annotations

import aiohttp
import asyncio
import datetime
import logging
from typing import List, Dict, Tuple, Optional, Any

from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_TEMPERATURE,
    CONF_TOP_P,
    CONF_REQUEST_TIMEOUT,
    CONF_OPENAI_API_KEY,
    CONF_REMEMBER_CONVERSATION,
    CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_GENERIC_OPENAI_PATH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_GENERIC_OPENAI_PATH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult

_LOGGER = logging.getLogger(__name__)

class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
    api_host: str
    api_key: str
    model_name: str

    async def _async_load_model(self, entry: ConfigEntry) -> None:
        self.api_host = format_url(
            hostname=entry.data[CONF_HOST],
            port=entry.data[CONF_PORT],
            ssl=entry.data[CONF_SSL],
            path=""
        )

        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)

    async def _async_generate_with_parameters(self, endpoint: str, additional_params: dict) -> TextGenerationResult:
        """Generate a response using the OpenAI-compatible API"""

        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)

        request_params = {
            "model": self.model_name,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        request_params.update(additional_params)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        session = async_get_clientsession(self.hass)
        response = None
        try:
            async with session.post(
                f"{self.api_host}{endpoint}",
                json=request_params,
                timeout=timeout,
                headers=headers
            ) as response:
                response.raise_for_status()
                result = await response.json()
        except asyncio.TimeoutError:
            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {response}")
            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

        _LOGGER.debug(result)

        return self._extract_response(result)

    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        raise NotImplementedError("Subclasses must implement _extract_response()")

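Any other OpenAI-compatible server could be wired in by subclassing this base class and supplying the two hooks it leaves open. A minimal sketch, where the class name, endpoint path, and payload shape are illustrative assumptions rather than part of this commit:

class MyCustomAPIAgent(BaseOpenAICompatibleAPIAgent):
    """Hypothetical subclass showing the contract the base class expects."""

    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        # assumed payload shape: {"text": "...", "stop_reason": "stop"}
        return TextGenerationResult(
            response=response_json["text"],
            stop_reason=response_json.get("stop_reason"),
            response_streamed=False,
        )

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
        # POST to the server's generate endpoint using the shared sampling params
        return await self._async_generate_with_parameters(
            "/api/v1/generate",  # hypothetical endpoint
            {"prompt": self._format_prompt(conversation)},
        )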
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
    """Implements the OpenAPI-compatible text completion and chat completion API backends."""

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/chat/completions"
        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

        return endpoint, request_params

    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/completions"
        request_params["prompt"] = self._format_prompt(conversation)

        return endpoint, request_params

    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        choice = response_json["choices"][0]
        if response_json["object"] == "chat.completion":
            response_text = choice["message"]["content"]
            streamed = False
        elif response_json["object"] == "chat.completion.chunk":
            # streamed chunks carry the partial text in the "delta" object, not "message"
            response_text = choice["delta"].get("content", "")
            streamed = True
        else:
            response_text = choice["text"]
            streamed = False

        if not streamed or (streamed and choice["finish_reason"]):
            if choice["finish_reason"] == "length" or choice["finish_reason"] == "content_filter":
                _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

        return TextGenerationResult(
            response=response_text,
            stop_reason=choice["finish_reason"],
            response_streamed=streamed
        )

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)

        if use_chat_api:
            endpoint, additional_params = self._chat_completion_params(conversation)
        else:
            endpoint, additional_params = self._completion_params(conversation)

        result = await self._async_generate_with_parameters(endpoint, additional_params)

        return result

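For reference, the two request shapes this agent produces differ only in how the conversation is packed; a sketch of the JSON that ends up in request_params (model name and sampling values are illustrative):

# chat endpoint: /<api_base_path>/chat/completions
chat_body = {
    "model": "my-model",
    "max_tokens": 128,
    "temperature": 0.1,
    "top_p": 1.0,
    "messages": [
        {"role": "system", "content": "You are a Home Assistant voice assistant ..."},
        {"role": "user", "content": "turn on the kitchen lights"},
    ],
}

# completion endpoint: /<api_base_path>/completions
completion_body = {
    "model": "my-model",
    "max_tokens": 128,
    "temperature": 0.1,
    "top_p": 1.0,
    "prompt": "<formatted prompt from _format_prompt()>",
}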
class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
    """Implements the OpenAPI-compatible Responses API backend."""

    _last_response_id: str | None = None
    _last_response_id_time: datetime.datetime | None = None

    def _responses_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/responses"
        request_params["input"] = conversation[-1]["message"] # last message in the conversation is the user input

        # Assign previous_response_id if relevant
        if self._last_response_id and self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
            # If the last response was generated recently enough, use it as context
            configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=self.entry.options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
            last_conversation_age: datetime.timedelta = datetime.datetime.now() - self._last_response_id_time
            _LOGGER.debug(f"Conversation ID age: {last_conversation_age}")
            if last_conversation_age < configured_memory_time:
                _LOGGER.debug(f"Using previous response ID {self._last_response_id} for context")
                request_params["previous_response_id"] = self._last_response_id
            else:
                _LOGGER.debug(f"Previous response ID {self._last_response_id} is too old, not using it for context")

        return endpoint, request_params

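Concretely, a follow-up turn inside the configured memory window produces a request body roughly like the following (IDs and values are illustrative):

responses_body = {
    "model": "my-model",
    "max_tokens": 128,
    "temperature": 0.1,
    "top_p": 1.0,
    "input": "turn off the office lights",
    # only present when the previous turn is newer than
    # CONF_REMEMBER_CONVERSATION_TIME_MINUTES
    "previous_response_id": "resp_abc123",
}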
    def _validate_response_payload(self, response_json: dict) -> bool:
        """
        Validate that the payload given matches the expected structure for the Responses API.

        API ref: https://platform.openai.com/docs/api-reference/responses/object

        Returns True or raises an error
        """
        required_response_keys = ["object", "output", "status", "id"]
        missing_keys = [key for key in required_response_keys if key not in response_json]
        if missing_keys:
            raise ValueError(f"Response JSON is missing required keys: {', '.join(missing_keys)}")

        if response_json["object"] != "response":
            raise ValueError(f"Response JSON object is not 'response', got {response_json['object']}")

        if "error" in response_json and response_json["error"] is not None:
            error = response_json["error"]
            _LOGGER.error("Response contained an error payload.")
            if "message" not in error:
                raise ValueError("Response JSON error is missing 'message' key")
            raise ValueError(f"Response JSON error: {error['message']}")

        return True

    def _check_response_status(self, response_json: dict) -> None:
        """
        Check the status of the response and log a warning if it is not 'completed'.

        API ref: https://platform.openai.com/docs/api-reference/responses/object#responses_object-status
        """
        if response_json["status"] != "completed":
            _LOGGER.warning(f"Response status is not 'completed', got {response_json['status']}. Details: {response_json.get('incomplete_details', 'No details provided')}")

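A minimal payload that passes the validation above and reaches _extract_response looks roughly like this (values are illustrative; only the keys checked above are required):

example_response = {
    "id": "resp_abc123",
    "object": "response",
    "status": "completed",
    "error": None,
    "output": [
        {
            "type": "message",
            "content": [
                {"type": "output_text", "text": "Turning on the kitchen lights."},
            ],
        },
    ],
}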
    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        self._validate_response_payload(response_json)
        self._check_response_status(response_json)

        outputs = response_json["output"]

        if len(outputs) > 1:
            _LOGGER.warning("Received multiple outputs from the Responses API, returning the first one.")

        output = outputs[0]

        if output["type"] != "message":
            raise NotImplementedError(f"Response output type is not 'message', got {output['type']}")

        if len(output["content"]) > 1:
            _LOGGER.warning("Received multiple content items in the response output, returning the first one.")

        content = output["content"][0]

        output_type = content["type"]

        to_return: str | None = None

        if output_type == "refusal":
            _LOGGER.info("Received a refusal from the Responses API.")
            to_return = content["refusal"]
        elif output_type == "output_text":
            to_return = content["text"]
        else:
            raise ValueError(f"Response output content type is not expected, got {output_type}")

        # Save the response_id and return the successful response.
        response_id = response_json["id"]
        self._last_response_id = response_id
        self._last_response_id_time = datetime.datetime.now()

        # wrap the extracted text so the return type matches the base class contract
        return TextGenerationResult(response=to_return, response_streamed=False)

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
        """Generate a response using the OpenAI-compatible Responses API"""

        endpoint, additional_params = self._responses_params(conversation)
        return await self._async_generate_with_parameters(endpoint, additional_params)
412
custom_components/llama_conversation/backends/llamacpp.py
Normal file
@@ -0,0 +1,412 @@
"""Defines the llama cpp agent"""
from __future__ import annotations

import importlib
import logging
import os
import threading
import time
from typing import Any, Callable, Generator, Optional, List, Dict, AsyncIterable

from homeassistant.components import conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import callback
from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later

from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation
from custom_components.llama_conversation.const import (
    CONF_MAX_TOKENS,
    CONF_PROMPT,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_TYPICAL_P,
    CONF_MIN_P,
    CONF_DOWNLOADED_MODEL_FILE,
    CONF_ENABLE_FLASH_ATTENTION,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_PROMPT_CACHING_ENABLED,
    CONF_PROMPT_CACHING_INTERVAL,
    CONF_CONTEXT_LENGTH,
    CONF_BATCH_SIZE,
    CONF_THREAD_COUNT,
    CONF_BATCH_THREAD_COUNT,
    DEFAULT_MAX_TOKENS,
    DEFAULT_PROMPT,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MIN_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_ENABLE_FLASH_ATTENTION,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_PROMPT_CACHING_ENABLED,
    DEFAULT_PROMPT_CACHING_INTERVAL,
    DEFAULT_CONTEXT_LENGTH,
    DEFAULT_BATCH_SIZE,
    DEFAULT_THREAD_COUNT,
    DEFAULT_BATCH_THREAD_COUNT,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult

# make type checking work for llama-cpp-python without importing it directly at runtime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from llama_cpp import Llama as LlamaType
else:
    LlamaType = Any

_LOGGER = logging.getLogger(__name__)

class LlamaCppAgent(LocalLLMAgent):
    model_path: str
    llm: LlamaType
    grammar: Any
    llama_cpp_module: Any
    remove_prompt_caching_listener: Callable
    model_lock: threading.Lock
    last_cache_prime: float
    last_updated_entities: dict[str, float]
    cache_refresh_after_cooldown: bool
    loaded_model_settings: dict[str, Any]

    _attr_supports_streaming = True

    def _load_model(self, entry: ConfigEntry) -> None:
        self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)

        _LOGGER.info(
            "Using model file '%s'", self.model_path
        )

        if not self.model_path:
            raise Exception(f"Model was not found at '{self.model_path}'!")

        validate_llama_cpp_python_installation()

        # don't import it until now because the wheel is installed by config_flow.py
        try:
            self.llama_cpp_module = importlib.import_module("llama_cpp")
        except ModuleNotFoundError:
            # attempt to re-install llama-cpp-python if it was uninstalled for some reason
            install_result = install_llama_cpp_python(self.hass.config.config_dir)
            if install_result is not True:
                raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")

            validate_llama_cpp_python_installation()
            self.llama_cpp_module = importlib.import_module("llama_cpp")

        Llama = getattr(self.llama_cpp_module, "Llama")

        _LOGGER.debug(f"Loading model '{self.model_path}'...")
        self.loaded_model_settings = {}
        self.loaded_model_settings[CONF_CONTEXT_LENGTH] = entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
        self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
        self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
        self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)

        self.llm = Llama(
            model_path=self.model_path,
            n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
            n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
            n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
            n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
            flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
        )
        _LOGGER.debug("Model loaded")

        self.grammar = None
        if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE))


        # TODO: check about disk caching
        # self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
        #     capacity_bytes=(512 * 10e8),
        #     cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
        # ))

        self.remove_prompt_caching_listener = None
        self.last_cache_prime = None
        self.last_updated_entities = {}
        self.cache_refresh_after_cooldown = False
        self.model_lock = threading.Lock()

        self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
        if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED]:
            @callback
            async def enable_caching_after_startup(_now) -> None:
                self._set_prompt_caching(enabled=True)
                await self._async_cache_prompt(None, None, None)
            async_call_later(self.hass, 5.0, enable_caching_after_startup)

    def _load_grammar(self, filename: str):
        LlamaGrammar = getattr(self.llama_cpp_module, "LlamaGrammar")
        _LOGGER.debug(f"Loading grammar {filename}...")
        try:
            with open(os.path.join(os.path.dirname(__file__), filename)) as f:
                grammar_str = "".join(f.readlines())
            self.grammar = LlamaGrammar.from_string(grammar_str)
            self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] = filename
            _LOGGER.debug("Loaded grammar")
        except Exception:
            _LOGGER.exception("Failed to load grammar!")
            self.grammar = None

    def _update_options(self):
        LocalLLMAgent._update_options(self)

        model_reloaded = False
        if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
            self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
            self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
            self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT) or \
            self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION):

            _LOGGER.debug(f"Reloading model '{self.model_path}'...")
            self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
            self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
            self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
            self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
            self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)

            Llama = getattr(self.llama_cpp_module, "Llama")
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
                n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
                n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
                n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
                flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
            )
            _LOGGER.debug("Model loaded")
            model_reloaded = True

        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            current_grammar = self.entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
            if not self.grammar or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
                self._load_grammar(current_grammar)
        else:
            self.grammar = None

        if self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
            self._set_prompt_caching(enabled=True)

            if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \
                model_reloaded:
                self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)

                async def cache_current_prompt(_now):
                    await self._async_cache_prompt(None, None, None)
                async_call_later(self.hass, 1.0, cache_current_prompt)
        else:
            self._set_prompt_caching(enabled=False)

    def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
        """Takes the super class function results and sorts the entities with the recently updated at the end"""
        entities, domains = LocalLLMAgent._async_get_exposed_entities(self)

        # ignore sorting if prompt caching is disabled
        if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
            return entities, domains

        entity_order = { name: None for name in entities.keys() }
        entity_order.update(self.last_updated_entities)

        def sort_key(item):
            item_name, last_updated = item
            # Handle cases where last updated is None
            if last_updated is None:
                return (False, '', item_name)
            else:
                return (True, last_updated, '')

        # Sort the items based on the sort_key function
        sorted_items = sorted(list(entity_order.items()), key=sort_key)

        _LOGGER.debug(f"sorted_items: {sorted_items}")

        sorted_entities = {}
        for item_name, _ in sorted_items:
            sorted_entities[item_name] = entities[item_name]

        return sorted_entities, domains

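The effect of sort_key is that never-updated entities come first (alphabetically) and recently-updated entities sink to the end of the prompt, which keeps the stable prefix reusable by the KV cache. A small worked illustration (entity names and timestamps are made up):

def sort_key(item):
    item_name, last_updated = item
    if last_updated is None:
        return (False, '', item_name)
    return (True, last_updated, '')

entity_order = {"climate.hall": None, "light.kitchen": None, "light.office": None}
entity_order.update({"light.office": 1700000100.0, "light.kitchen": 1700000200.0})

print([name for name, _ in sorted(entity_order.items(), key=sort_key)])
# ['climate.hall', 'light.office', 'light.kitchen']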
    def _set_prompt_caching(self, *, enabled=True):
        if enabled and not self.remove_prompt_caching_listener:
            _LOGGER.info("enabling prompt caching...")

            entity_ids = [
                state.entity_id for state in self.hass.states.async_all() \
                if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
            ]

            _LOGGER.debug(f"watching entities: {entity_ids}")

            self.remove_prompt_caching_listener = async_track_state_change(self.hass, entity_ids, self._async_cache_prompt)

        elif not enabled and self.remove_prompt_caching_listener:
            _LOGGER.info("disabling prompt caching...")
            self.remove_prompt_caching_listener()

    @callback
    async def _async_cache_prompt(self, entity, old_state, new_state):
        refresh_start = time.time()

        # track last update time so we can sort the context efficiently
        if entity:
            self.last_updated_entities[entity] = refresh_start

        llm_api: llm.APIInstance | None = None
        if self.entry.options.get(CONF_LLM_HASS_API):
            try:
                llm_api = await llm.async_get_api(
                    self.hass, self.entry.options[CONF_LLM_HASS_API]
                )
            except HomeAssistantError:
                _LOGGER.exception("Failed to get LLM API when caching prompt!")
                return

        _LOGGER.debug(f"refreshing cached prompt because {entity} changed...")
        await self.hass.async_add_executor_job(self._cache_prompt, llm_api)

        refresh_end = time.time()
        _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")

    def _cache_prompt(self, llm_api: llm.APIInstance | None) -> None:
        # if a refresh is already scheduled then exit
        if self.cache_refresh_after_cooldown:
            return

        # if we are inside the cooldown period, request a refresh and exit
        current_time = time.time()
        fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
        if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval:
            self.cache_refresh_after_cooldown = True
            return

        # try to acquire the lock, if we are still running for some reason, request a refresh and exit
        lock_acquired = self.model_lock.acquire(False)
        if not lock_acquired:
            self.cache_refresh_after_cooldown = True
            return

        try:
            raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
            prompt = self._format_prompt([
                { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
                { "role": "user", "message": "" }
            ], include_generation_prompt=False)

            input_tokens = self.llm.tokenize(
                prompt.encode(), add_bos=False
            )

            temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
            top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
            top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
            grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

            _LOGGER.debug(f"Options: {self.entry.options}")

            _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")

            # grab just one token. should prime the kv cache with the system prompt
            next(self.llm.generate(
                input_tokens,
                temp=temperature,
                top_k=top_k,
                top_p=top_p,
                grammar=grammar
            ))

            self.last_cache_prime = time.time()
        finally:
            self.model_lock.release()


        # schedule a refresh using async_call_later
        # if the flag is set after the delay then we do another refresh
        @callback
        async def refresh_if_requested(_now):
            if self.cache_refresh_after_cooldown:
                self.cache_refresh_after_cooldown = False

                refresh_start = time.time()
                _LOGGER.debug("refreshing cached prompt after cooldown...")
                # re-use the same LLM API instance that triggered this refresh
                await self.hass.async_add_executor_job(self._cache_prompt, llm_api)

                refresh_end = time.time()
                _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")

        refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
        async_call_later(self.hass, float(refresh_delay), refresh_if_requested)

    def _generate_stream(self, conversation: List[Dict[str, str]], user_input: conversation.ConversationInput, chat_log: conversation.ChatLog) -> TextGenerationResult:
        prompt = self._format_prompt(conversation)

        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        min_p = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)

        _LOGGER.debug(f"Options: {self.entry.options}")

        with self.model_lock:
            input_tokens = self.llm.tokenize(
                prompt.encode(), add_bos=False
            )

            context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
            if len(input_tokens) >= context_len:
                num_entities = len(self._async_get_exposed_entities()[0])
                context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
                self._warn_context_size()
                raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
            if len(input_tokens) + max_tokens >= context_len:
                self._warn_context_size()

            _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
            output_tokens = self.llm.generate(
                input_tokens,
                temp=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                grammar=self.grammar
            )

            result_tokens = []
            async def async_iterator():
                num_tokens = 0
                for token in output_tokens:
                    result_tokens.append(token)
                    yield TextGenerationResult(response=self.llm.detokenize([token]).decode(), response_streamed=True)

                    if token == self.llm.token_eos():
                        break

                    if len(result_tokens) >= max_tokens:
                        break

                    num_tokens += 1

            self._transform_result_stream(async_iterator(), user_input=user_input, chat_log=chat_log)

            response = TextGenerationResult(
                response=self.llm.detokenize(result_tokens).decode(),
                response_streamed=True,
            )
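For context, a consumer of such a stream accumulates the per-token TextGenerationResult chunks back into the final reply text. A minimal sketch of that pattern, where the stream argument stands in for whatever async iterable _transform_result_stream ultimately hands to the chat log (an assumption here, not shown in this diff):

async def collect_stream(stream) -> str:
    """Accumulate incremental TextGenerationResult chunks into the full response text."""
    pieces: list[str] = []
    async for chunk in stream:
        # each chunk carries only the newly decoded token(s)
        pieces.append(chunk.response)
    return "".join(pieces)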
198
custom_components/llama_conversation/backends/ollama.py
Normal file
@@ -0,0 +1,198 @@
"""Defines the ollama compatible agent"""
from __future__ import annotations

import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any

from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_TYPICAL_P,
    CONF_REQUEST_TIMEOUT,
    CONF_OPENAI_API_KEY,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
)

from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult

_LOGGER = logging.getLogger(__name__)

class OllamaAPIAgent(LocalLLMAgent):
    api_host: str
    api_key: Optional[str]
    model_name: Optional[str]

    async def _async_load_model(self, entry: ConfigEntry) -> None:
        self.api_host = format_url(
            hostname=entry.data[CONF_HOST],
            port=entry.data[CONF_PORT],
            ssl=entry.data[CONF_SSL],
            path=""
        )
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)

        # ollama handles loading for us so just make sure the model is available
        try:
            headers = {}
            session = async_get_clientsession(self.hass)
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"

            async with session.get(
                f"{self.api_host}/api/tags",
                headers=headers,
            ) as response:
                response.raise_for_status()
                currently_downloaded_result = await response.json()

        except Exception as ex:
            _LOGGER.debug("Connection error was: %s", repr(ex))
            raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex

        model_names = [ x["name"] for x in currently_downloaded_result["models"] ]
        if ":" in self.model_name:
            if not any([ name == self.model_name for name in model_names]):
                raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
        elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
            raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
        request_params = {}

        endpoint = "/api/chat"
        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

        return endpoint, request_params

    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}

        endpoint = "/api/generate"
        request_params["prompt"] = self._format_prompt(conversation)
        request_params["raw"] = True # ignore prompt template

        return endpoint, request_params

    def _extract_response(self, response_json: Dict) -> TextGenerationResult:
        # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
        # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        # max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        # if response_json["prompt_eval_count"] + max_tokens > context_len:
        #     self._warn_context_size()

        if "response" in response_json:
            response = response_json["response"]
            tool_calls = None
            stop_reason = None
            if response_json["done"] not in ["true", True]:
                _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
        else:
            response = response_json["message"]["content"]
            tool_calls = response_json["message"].get("tool_calls")
            stop_reason = response_json.get("done_reason")

        return TextGenerationResult(
            response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
        )

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)

        request_params = {
            "model": self.model_name,
            "stream": True,
            "keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
            "options": {
                "num_ctx": context_length,
                "top_p": top_p,
                "top_k": top_k,
                "typical_p": typical_p,
                "temperature": temperature,
                "num_predict": max_tokens,
            }
        }

        if json_mode:
            request_params["format"] = "json"

        if use_chat_api:
            endpoint, additional_params = self._chat_completion_params(conversation)
        else:
            endpoint, additional_params = self._completion_params(conversation)

        request_params.update(additional_params)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        session = async_get_clientsession(self.hass)
        response = None
        result = TextGenerationResult(
            response="", response_streamed=True
        )
        try:
            async with session.post(
                f"{self.api_host}{endpoint}",
                json=request_params,
                timeout=timeout,
                headers=headers
            ) as response:
                response.raise_for_status()

                while True:
                    chunk = await response.content.readline()
                    if not chunk:
                        break

                    parsed_chunk = self._extract_response(json.loads(chunk))
                    result.response += parsed_chunk.response
                    result.stop_reason = parsed_chunk.stop_reason
                    result.tool_calls = parsed_chunk.tool_calls
        except asyncio.TimeoutError:
            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {response}")
            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

        _LOGGER.debug(result)

        return result
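Because the request sets "stream": True, Ollama answers with newline-delimited JSON, so each readline() above yields one complete object; the final object carries "done": true and, on /api/chat, may include done_reason and tool_calls. An illustrative pair of stream lines (model name and token text are made up):

ndjson_lines = [
    b'{"model":"llama3","message":{"role":"assistant","content":"Turning"},"done":false}\n',
    b'{"model":"llama3","message":{"role":"assistant","content":" on the lights."},"done":true,"done_reason":"stop"}\n',
]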
171
custom_components/llama_conversation/backends/tailored_openai.py
Normal file
@@ -0,0 +1,171 @@
"""Defines the various openai-like agents"""
from __future__ import annotations

import logging
import os
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass

from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from custom_components.llama_conversation.const import (
    CONF_MAX_TOKENS,
    CONF_TOP_K,
    CONF_TYPICAL_P,
    CONF_MIN_P,
    CONF_PROMPT_TEMPLATE,
    CONF_USE_GBNF_GRAMMAR,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_CONTEXT_LENGTH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TOP_K,
    DEFAULT_MIN_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_PROMPT_TEMPLATE,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_CONTEXT_LENGTH,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
    TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
)
from custom_components.llama_conversation.conversation import TextGenerationResult
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent

_LOGGER = logging.getLogger(__name__)

class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
    admin_key: str

    async def _async_load_model(self, entry: ConfigEntry) -> None:
        await super()._async_load_model(entry)
        self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)

        try:
            headers = {}
            session = async_get_clientsession(self.hass)

            if self.admin_key:
                headers["Authorization"] = f"Bearer {self.admin_key}"

            async with session.get(
                f"{self.api_host}/v1/internal/model/info",
                headers=headers
            ) as response:
                response.raise_for_status()
                currently_loaded_result = await response.json()

            loaded_model = currently_loaded_result["model_name"]
            if loaded_model == self.model_name:
                _LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
                return
            else:
                _LOGGER.info(f"Model {self.model_name} is not loaded on the remote backend. Loading it now...")

            async with session.post(
                f"{self.api_host}/v1/internal/model/load",
                json={
                    "model_name": self.model_name,
                    # TODO: expose arguments to the user in home assistant UI
                    # "args": {},
                },
                headers=headers
            ) as response:
                response.raise_for_status()

        except Exception as ex:
            _LOGGER.debug("Connection error was: %s", repr(ex))
            raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
        preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
        chat_mode = self.entry.options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE)

        endpoint, request_params = super()._chat_completion_params(conversation)

        request_params["mode"] = chat_mode
        if chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT or chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT:
            if preset:
                request_params["character"] = preset
        elif chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT:
            request_params["instruction_template"] = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)

        request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
        request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
        request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)

        return endpoint, request_params

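The resulting chat request therefore carries text-generation-webui-specific fields on top of the generic OpenAI ones, roughly as sketched below (values are illustrative and mirror what the updated tests further down assert):

webui_chat_extras = {
    "mode": "chat",                     # or "instruct" / "chat-instruct"
    "character": "Some Character",      # chat / chat-instruct mode only, when a preset is configured
    # "instruction_template": "chatml"  # instruct mode only
    "truncation_length": 2048,
    "top_k": 40,
    "min_p": 0.05,
    "typical_p": 1.0,
}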
    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)

        endpoint, request_params = super()._completion_params(conversation)

        if preset:
            request_params["preset"] = preset

        request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
        request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
        request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)

        return endpoint, request_params

    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        choices = response_json["choices"]
        if choices[0]["finish_reason"] != "stop":
            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

        context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        if response_json["usage"]["prompt_tokens"] + max_tokens > context_len:
            self._warn_context_size()

        # text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
        if response_json["object"] == "chat.completions":
            response_text = choices[0]["message"]["content"]
        else:
            response_text = choices[0]["text"]

        # wrap the text so the return type matches the base class contract
        return TextGenerationResult(
            response=response_text,
            stop_reason=choices[0]["finish_reason"],
            response_streamed=False
        )

class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
    """https://llama-cpp-python.readthedocs.io/en/latest/server/"""
    grammar: str

    async def _async_load_model(self, entry: ConfigEntry):
        await super()._async_load_model(entry)

        return await self.hass.async_add_executor_job(
            self._load_model, entry
        )

    def _load_model(self, entry: ConfigEntry):
        with open(os.path.join(os.path.dirname(__file__), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
            self.grammar = "".join(f.readlines())

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        endpoint, request_params = super()._chat_completion_params(conversation)

        request_params["top_k"] = top_k

        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            request_params["grammar"] = self.grammar

        return endpoint, request_params

    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
        endpoint, request_params = super()._completion_params(conversation)

        request_params["top_k"] = top_k

        if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
            request_params["grammar"] = self.grammar

        return endpoint, request_params
File diff suppressed because it is too large
@@ -1,9 +1,9 @@
# types from Home Assistant
homeassistant>=2024.6.1
homeassistant>=2024.7.0
hassil
home-assistant-intents

# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component
pytest-homeassistant-custom-component==0.13.260

@@ -1,5 +1,4 @@
-r requirements.txt
-r requirements-dev.txt
-r custom_components/requirements.txt
-r custom_components/requirements-dev.txt
-r data/requirements.txt
@@ -4,7 +4,10 @@ import pytest
import jinja2
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY

from custom_components.llama_conversation.conversation import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
from custom_components.llama_conversation.backends.llamacpp import LlamaCppAgent
from custom_components.llama_conversation.backends.ollama import OllamaAPIAgent
from custom_components.llama_conversation.backends.tailored_openai import TextGenerationWebuiAgent
from custom_components.llama_conversation.backends.generic_openai import GenericOpenAIAPIAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -106,7 +109,8 @@ class MockConfigEntry:
def __init__(self, entry_id='test_entry_id', data={}, options={}):
self.entry_id = entry_id
self.data = WarnDict(data)
self.options = WarnDict(options)
# Use a mutable dict for options in tests
self.options = WarnDict(dict(options))


@pytest.fixture
@@ -138,14 +142,16 @@ def local_llama_agent_fixture(config_entry, hass, enable_custom_integrations):
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
patch('custom_components.llama_conversation.backends.llamacpp.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.utils.importlib.import_module') as import_module_mock_2, \
patch('custom_components.llama_conversation.utils.install_llama_cpp_python') as install_llama_cpp_python_mock:

entry_mock.return_value = config_entry
llama_instance_mock = MagicMock()
llama_class_mock = MagicMock()
llama_class_mock.return_value = llama_instance_mock
import_module_mock.return_value = MagicMock(Llama=llama_class_mock)
import_module_mock_2.return_value = MagicMock(Llama=llama_class_mock)
install_llama_cpp_python_mock.return_value = True
get_exposed_entities_mock.return_value = (
{
@@ -194,7 +200,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await local_llama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

# assert on results + check side effects
@@ -249,7 +255,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):

# do another turn of the same conversation
result = await local_llama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -272,8 +278,7 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
patch.object(OllamaAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
patch('custom_components.llama_conversation.backends.ollama.async_get_clientsession') as get_clientsession:

entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
@@ -288,7 +293,7 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):

response_mock = MagicMock()
response_mock.json.return_value = { "models": [ {"name": config_entry.data[CONF_CHAT_MODEL] }] }
requests_get_mock.return_value = response_mock
get_clientsession.get.return_value = response_mock

call_tool_mock.return_value = {"result": "success"}

@@ -298,8 +303,8 @@ def ollama_agent_fixture(config_entry, hass, enable_custom_integrations):
)

all_mocks = {
"requests_get": requests_get_mock,
"requests_post": requests_post_mock
"requests_get": get_clientsession.get,
"requests_post": get_clientsession.post
}

yield agent_obj, all_mocks
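Since the agents now go through async_get_clientsession() and use session.get/post as async context managers, the fixtures have to mock that shape rather than requests. A minimal sketch of one way to build such a mock (the helper name is illustrative, not part of the test suite):

from unittest.mock import AsyncMock, MagicMock

def make_mock_session(json_payload):
    """Fake aiohttp session whose get()/post() behave as async context managers."""
    response_mock = MagicMock()
    response_mock.json = AsyncMock(return_value=json_payload)

    session_mock = MagicMock()
    # MagicMock pre-creates __aenter__/__aexit__, so "async with" works out of the box
    session_mock.get.return_value.__aenter__.return_value = response_mock
    session_mock.post.return_value.__aenter__.return_value = response_mock
    return session_mock

# e.g. in a fixture:
# get_clientsession.return_value = make_mock_session({"models": [{"name": "llama3:latest"}]})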
@@ -339,7 +344,7 @@ async def test_ollama_agent(ollama_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await ollama_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

# assert on results + check side effects
@@ -384,7 +389,7 @@ async def test_ollama_agent(ollama_agent_fixture):

# do another turn of the same conversation
result = await ollama_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -418,8 +423,7 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integr
patch.object(TextGenerationWebuiAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
patch('custom_components.llama_conversation.backends.tailored_openai.async_get_clientsession') as get_clientsession_mock:

entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
@@ -434,7 +438,7 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integr

response_mock = MagicMock()
response_mock.json.return_value = { "model_name": config_entry.data[CONF_CHAT_MODEL] }
requests_get_mock.return_value = response_mock
get_clientsession_mock.get.return_value = response_mock

call_tool_mock.return_value = {"result": "success"}

@@ -444,8 +448,8 @@ def text_generation_webui_agent_fixture(config_entry, hass, enable_custom_integr
)

all_mocks = {
"requests_get": requests_get_mock,
"requests_post": requests_post_mock
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}

yield agent_obj, all_mocks
@@ -490,7 +494,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

# assert on results + check side effects
@@ -521,7 +525,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):

# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -588,7 +592,7 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):

# do another turn of the same conversation but the chat endpoint
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"
@@ -611,69 +615,6 @@ async def test_text_generation_webui_agent(text_generation_webui_agent_fixture):
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)

# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()

text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_PRESET] = "Some Character"

# do another turn of the same conversation and use a preset
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"

all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"character": "Some Character",
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)

# reset mock stats
for mock in all_mocks.values():
mock.reset_mock()

text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE] = TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT

# do another turn of the same conversation and use instruct mode
result = await text_generation_webui_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"

all_mocks["requests_post"].assert_called_once_with(
"http://localhost:5000/v1/chat/completions",
json={
"model": text_generation_webui_agent.entry.data[CONF_CHAT_MODEL],
"top_p": text_generation_webui_agent.entry.options[CONF_TOP_P],
"top_k": text_generation_webui_agent.entry.options[CONF_TOP_K],
"min_p": text_generation_webui_agent.entry.options[CONF_MIN_P],
"typical_p": text_generation_webui_agent.entry.options[CONF_TYPICAL_P],
"temperature": text_generation_webui_agent.entry.options[CONF_TEMPERATURE],
"truncation_length": text_generation_webui_agent.entry.options[CONF_CONTEXT_LENGTH],
"max_tokens": text_generation_webui_agent.entry.options[CONF_MAX_TOKENS],
"mode": text_generation_webui_agent.entry.options[CONF_TEXT_GEN_WEBUI_CHAT_MODE],
"instruction_template": "chatml",
"messages": ANY
},
headers={ "Authorization": "Bearer OpenAI-API-Key" },
timeout=text_generation_webui_agent.entry.options[CONF_REQUEST_TIMEOUT]
)

@pytest.fixture
def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations):
@@ -682,8 +623,7 @@ def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations)
patch.object(GenericOpenAIAPIAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch.object(APIInstance, 'async_call_tool') as call_tool_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.requests.get') as requests_get_mock, \
patch('custom_components.llama_conversation.agent.requests.post') as requests_post_mock:
patch('custom_components.llama_conversation.backends.generic_openai.async_get_clientsession') as get_clientsession_mock:

entry_mock.return_value = config_entry
get_exposed_entities_mock.return_value = (
@@ -704,8 +644,8 @@ def generic_openai_agent_fixture(config_entry, hass, enable_custom_integrations)
)

all_mocks = {
"requests_get": requests_get_mock,
"requests_post": requests_post_mock
"requests_get": get_clientsession_mock.get,
"requests_post": get_clientsession_mock.post
}

yield agent_obj, all_mocks
@@ -745,7 +685,7 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):
# invoke the conversation agent
conversation_id = "test-conversation"
result = await generic_openai_agent.async_process(ConversationInput(
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en"
"turn on the kitchen lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

# assert on results + check side effects
@@ -804,7 +744,7 @@ async def test_generic_openai_agent(generic_openai_agent_fixture):

# do another turn of the same conversation but the chat endpoint
result = await generic_openai_agent.async_process(ConversationInput(
"turn off the office lights", MagicMock(), conversation_id, None, "en"
"turn off the office lights", MagicMock(), conversation_id, None, "en", agent_id="test-agent"
))

assert result.response.speech['plain']['speech'] == "I am saying something!"