Use the ollama Python client to better handle compatibility
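
Replaces the hand-rolled aiohttp calls to /api/chat and /api/tags with the official
ollama Python client (AsyncClient), which takes care of streaming, tool calls, think
mode, JSON mode, and keep_alive handling across server versions. A rough sketch of the
pattern the new code follows (illustrative only: the host, token, model name, and
option values below are placeholders; the real implementation also wires in
certifi-based SSL verification, tool definitions from the LLM API, and maps failures
to HomeAssistantError):

    import asyncio
    import httpx
    from ollama import AsyncClient, ResponseError

    async def demo() -> None:
        # Build the client the way OllamaAPIClient._build_client does:
        # base URL, optional bearer-token header, and an httpx timeout.
        client = AsyncClient(
            host="http://localhost:11434",            # placeholder: default Ollama port
            headers={"Authorization": "Bearer key"},  # only set when an API key is configured
            timeout=httpx.Timeout(90),
        )
        try:
            # Connection validation now goes through the client instead of GET /api/tags.
            await client.list()

            # Streaming chat replaces the hand-rolled aiohttp readline() loop.
            stream = await client.chat(
                model="llama3",                       # placeholder: any locally pulled model
                messages=[{"role": "user", "content": "hi"}],
                stream=True,
                keep_alive="30m",                     # keep the model loaded between requests
                options={"num_ctx": 8192, "temperature": 0.1},
            )
            async for chunk in stream:
                print(chunk.message.content or "", end="")
        except (ResponseError, ConnectionError, httpx.TimeoutException) as err:
            print(f"Ollama request failed: {err}")

    asyncio.run(demo())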

Alex O'Connell
2025-12-13 18:26:04 -05:00
parent 25b6ddfd0c
commit 5f48b403d4
4 changed files with 168 additions and 114 deletions

View File

@@ -1,18 +1,19 @@
"""Defines the ollama compatible agent"""
"""Defines the Ollama compatible agent backed by the official python client."""
from __future__ import annotations
from warnings import deprecated
import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
import ssl
from collections.abc import Mapping
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import certifi
import httpx
from ollama import AsyncClient, ChatResponse, ResponseError
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation as conversation
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
@@ -23,119 +24,179 @@ from custom_components.llama_conversation.const import (
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_ENABLE_THINK_MODE,
CONF_REQUEST_TIMEOUT,
CONF_OPENAI_API_KEY,
CONF_GENERIC_OPENAI_PATH,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
DEFAULT_MIN_P,
DEFAULT_ENABLE_THINK_MODE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_GENERIC_OPENAI_PATH,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
)
from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult
_LOGGER = logging.getLogger(__name__)
@deprecated("Use the built-in Ollama integration instead")
def _normalize_path(path: str | None) -> str:
if not path:
return ""
trimmed = str(path).strip("/")
return f"/{trimmed}" if trimmed else ""
def _build_default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
try:
context.load_verify_locations(certifi.where())
except OSError as err:
_LOGGER.debug("Failed to load certifi bundle for Ollama client: %s", err)
return context
class OllamaAPIClient(LocalLLMClient):
api_host: str
api_key: Optional[str]
def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options)
base_path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
self.api_host = format_url(
hostname=client_options[CONF_HOST],
port=client_options[CONF_PORT],
ssl=client_options[CONF_SSL],
path=client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
path=base_path,
)
self.api_key = client_options.get(CONF_OPENAI_API_KEY) or None
self._headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
self._ssl_context = _build_default_ssl_context() if client_options.get(CONF_SSL) else None
self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
def _build_client(self, *, timeout: float | int | httpx.Timeout | None = None) -> AsyncClient:
timeout_config: httpx.Timeout | float | None = timeout
if isinstance(timeout, (int, float)):
timeout_config = httpx.Timeout(timeout)
return AsyncClient(
host=self.api_host,
headers=self._headers,
timeout=timeout_config,
verify=self._ssl_context,
)
@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
@staticmethod
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
headers = {}
api_key = user_input.get(CONF_OPENAI_API_KEY)
api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
base_path = _normalize_path(user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
timeout_config: httpx.Timeout | float | None = httpx.Timeout(5)
verify_context = None
if user_input.get(CONF_SSL):
verify_context = await hass.async_add_executor_job(_build_default_ssl_context)
client = AsyncClient(
host=format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=base_path,
),
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
timeout=timeout_config,
verify=verify_context,
)
try:
session = async_get_clientsession(hass)
async with session.get(
format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=f"/{api_base_path}/api/tags"
),
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
if response.ok:
return None
else:
return f"HTTP Status {response.status}"
except Exception as ex:
return str(ex)
await client.list()
except httpx.TimeoutException:
return "Connection timed out"
except ResponseError as err:
return f"HTTP Status {err.status_code}: {err.error}"
except ConnectionError as err:
return str(err)
return None
async def async_get_available_models(self) -> List[str]:
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
client = self._build_client(timeout=5)
try:
response = await client.list()
except httpx.TimeoutException as err:
raise HomeAssistantError("Timed out while fetching models from the Ollama server") from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to fetch models from the Ollama server: {err}") from err
session = async_get_clientsession(self.hass)
async with session.get(
f"{self.api_host}/api/tags",
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
response.raise_for_status()
models_result = await response.json()
models: List[str] = []
for model in getattr(response, "models", []) or []:
candidate = getattr(model, "name", None) or getattr(model, "model", None)
if candidate:
models.append(candidate)
return [x["name"] for x in models_result["models"]]
return models
def _extract_response(self, response_json: Dict) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
# if response_json["prompt_eval_count"] + max_tokens > context_len:
# self._warn_context_size()
def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
message = getattr(response_chunk, "message", None)
content = getattr(message, "content", None) if message else None
raw_tool_calls = getattr(message, "tool_calls", None) if message else None
if "response" in response_json:
response = response_json["response"]
tool_calls = None
stop_reason = None
if response_json["done"] not in ["true", True]:
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
else:
response = response_json["message"]["content"]
raw_tool_calls = response_json["message"].get("tool_calls")
tool_calls = [ llm.ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"]) for x in raw_tool_calls] if raw_tool_calls else None
stop_reason = response_json.get("done_reason")
tool_calls: Optional[List[llm.ToolInput]] = None
if raw_tool_calls:
parsed_tool_calls: list[llm.ToolInput] = []
for tool_call in raw_tool_calls:
function = getattr(tool_call, "function", None)
name = getattr(function, "name", None) if function else None
if not name:
continue
# _LOGGER.debug(f"{response=} {tool_calls=}")
arguments = getattr(function, "arguments", None) or {}
if isinstance(arguments, Mapping):
arguments_dict = dict(arguments)
else:
arguments_dict = {"raw": arguments}
return response, tool_calls
parsed_tool_calls.append(llm.ToolInput(tool_name=name, tool_args=arguments_dict))
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
if parsed_tool_calls:
tool_calls = parsed_tool_calls
if content is None and not tool_calls:
return None, None
return content, tool_calls
@staticmethod
def _format_keep_alive(value: Any) -> Any:
as_text = str(value).strip()
return 0 if as_text in {"0", "0.0"} else f"{as_text}m"
def _generate_stream(
self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options.get(CONF_CHAT_MODEL, "")
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -145,58 +206,47 @@ class OllamaAPIClient(LocalLLMClient):
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = entity_options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
think_mode = entity_options.get(CONF_ENABLE_THINK_MODE, DEFAULT_ENABLE_THINK_MODE)
json_mode = entity_options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
request_params = {
"model": model_name,
"stream": True,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
"options": {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
},
options = {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
"min_p": entity_options.get(CONF_MIN_P, DEFAULT_MIN_P),
}
if json_mode:
request_params["format"] = "json"
if llm_api:
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
endpoint = "/api/chat"
request_params["messages"] = get_oai_formatted_messages(conversation, tool_args_to_str=False)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
messages = get_oai_formatted_messages(conversation, tool_args_to_str=False)
tools = None
if llm_api and not legacy_tool_calling:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
keep_alive_payload = self._format_keep_alive(keep_alive)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
response = None
chunk = None
client = self._build_client(timeout=timeout)
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=aiohttp.ClientTimeout(total=timeout),
headers=headers
) as response:
response.raise_for_status()
while True:
chunk = await response.content.readline()
if not chunk:
break
yield self._extract_response(json.loads(chunk))
except asyncio.TimeoutError as err:
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
except aiohttp.ClientError as err:
stream = await client.chat(
model=model_name,
messages=messages,
tools=tools,
stream=True,
think=think_mode,
format="json" if json_mode else None,
options=options,
keep_alive=keep_alive_payload,
)
async for chunk in stream:
yield self._extract_response(chunk)
except httpx.TimeoutException as err:
raise HomeAssistantError(
"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
) from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())

View File

@@ -104,6 +104,8 @@ CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout"
DEFAULT_REQUEST_TIMEOUT = 90
CONF_ENABLE_THINK_MODE = "enable_think_mode"
DEFAULT_ENABLE_THINK_MODE = False
CONF_BACKEND_TYPE = "model_backend"
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
@@ -185,7 +187,7 @@ DEFAULT_GENERIC_OPENAI_PATH = "v1"
CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 2048
DEFAULT_CONTEXT_LENGTH = 8192
CONF_LLAMACPP_BATCH_SIZE = "batch_size"
DEFAULT_LLAMACPP_BATCH_SIZE = 512
CONF_LLAMACPP_THREAD_COUNT = "n_threads"

View File

@@ -11,6 +11,7 @@
"iot_class": "local_polling",
"requirements": [
"huggingface-hub>=0.23.0",
"webcolors>=24.8.0"
"webcolors>=24.8.0",
"ollama>=0.5.1"
]
}

View File

@@ -1,2 +1,3 @@
huggingface-hub>=0.23.0
webcolors>=24.8.0
ollama>=0.5.1