home-llm/custom_components/llama_conversation/backends/ollama.py
"""Defines the ollama compatible agent"""
from __future__ import annotations
from warnings import deprecated
import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_TYPICAL_P,
    CONF_REQUEST_TIMEOUT,
    CONF_OPENAI_API_KEY,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
_LOGGER = logging.getLogger(__name__)
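
# Agent that talks to an Ollama server over its native HTTP API:
# /api/tags is used to check model availability and /api/chat for
# streamed generation. Deprecated in favor of Home Assistant's
# built-in Ollama integration.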
@deprecated("Use the built-in Ollama integration instead")
class OllamaAPIAgent(LocalLLMAgent):
    api_host: str
    api_key: Optional[str]
    model_name: Optional[str]
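
    # Resolve connection settings from the config entry and confirm the
    # configured model is already present on the Ollama server (listed by
    # /api/tags); Ollama itself handles loading the model on demand.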
    async def _async_load_model(self, entry: ConfigEntry) -> None:
        self.api_host = format_url(
            hostname=entry.data[CONF_HOST],
            port=entry.data[CONF_PORT],
            ssl=entry.data[CONF_SSL],
            path=""
        )
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)

        # ollama handles loading for us so just make sure the model is available
        try:
            headers = {}
            session = async_get_clientsession(self.hass)
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"

            async with session.get(
                f"{self.api_host}/api/tags",
                headers=headers,
            ) as response:
                response.raise_for_status()
                currently_downloaded_result = await response.json()
        except Exception as ex:
            _LOGGER.debug("Connection error was: %s", repr(ex))
            raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex

        model_names = [x["name"] for x in currently_downloaded_result["models"]]
        if ":" in self.model_name:
            if not any(name == self.model_name for name in model_names):
                raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
        elif not any(name.split(":")[0] == self.model_name for name in model_names):
            raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
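
    # Convert a single streamed JSON chunk into a TextGenerationResult.
    # /api/generate chunks put the text under "response"; /api/chat chunks
    # put it under "message" along with any tool calls and the done reason.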
    def _extract_response(self, response_json: Dict) -> TextGenerationResult:
        # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
        # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        # max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        # if response_json["prompt_eval_count"] + max_tokens > context_len:
        #     self._warn_context_size()

        if "response" in response_json:
            response = response_json["response"]
            tool_calls = None
            stop_reason = None
            if response_json["done"] not in ["true", True]:
                _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
        else:
            response = response_json["message"]["content"]
            tool_calls = response_json["message"].get("tool_calls")
            stop_reason = response_json.get("done_reason")

        return TextGenerationResult(
            response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
        )
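
    # POST the request to the Ollama server and yield a TextGenerationResult
    # for each newline-delimited JSON chunk of the streamed response.
    # Timeouts and client errors are reported as error results rather than raised.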
    async def _async_generate_with_parameters(self, endpoint: str, request_params: dict[str, Any], headers: dict[str, Any], timeout: int) -> AsyncGenerator[TextGenerationResult, None]:
        session = async_get_clientsession(self.hass)
        response = None
        try:
            async with session.post(
                f"{self.api_host}{endpoint}",
                json=request_params,
                timeout=timeout,
                headers=headers
            ) as response:
                response.raise_for_status()

                while True:
                    chunk = await response.content.readline()
                    if not chunk:
                        break
                    yield self._extract_response(json.loads(chunk))
        except asyncio.TimeoutError:
            yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {response}")
            yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
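
    # Build the /api/chat request (sampling options, keep_alive, optional JSON
    # mode, and tool definitions when an LLM API is attached) from the config
    # entry options, then hand off to the streaming generator.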
    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)

        request_params = {
            "model": self.model_name,
            "stream": True,
            "keep_alive": f"{keep_alive}m",  # prevent ollama from unloading the model
            "options": {
                "num_ctx": context_length,
                "top_p": top_p,
                "top_k": top_k,
                "typical_p": typical_p,
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }

        if json_mode:
            request_params["format"] = "json"

        if llm_api:
            request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())

        endpoint = "/api/chat"
        request_params["messages"] = get_oai_formatted_messages(conversation)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        return self._async_generate_with_parameters(endpoint, request_params, headers, timeout)