Use the ollama python client to better handle compatibility
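The diff below replaces the hand-rolled aiohttp requests with the official `ollama` python client. For orientation, here is a minimal sketch of the streaming chat pattern that client provides; the host URL and model name are placeholders, not values from this commit:

```python
# Minimal sketch of the streaming pattern the new backend builds on.
# The host URL and model name are placeholders, not part of this commit.
import asyncio

import httpx
from ollama import AsyncClient, ResponseError


async def main() -> None:
    client = AsyncClient(host="http://127.0.0.1:11434", timeout=httpx.Timeout(30))
    try:
        # stream=True makes chat() return an async iterator of ChatResponse chunks.
        stream = await client.chat(
            model="llama3.2",
            messages=[{"role": "user", "content": "Turn off the kitchen lights."}],
            stream=True,
        )
        async for chunk in stream:
            print(chunk.message.content or "", end="", flush=True)
    except (ResponseError, ConnectionError, httpx.TimeoutException) as err:
        print(f"Ollama request failed: {err}")


asyncio.run(main())
```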
@@ -1,18 +1,19 @@
"""Defines the ollama compatible agent"""
"""Defines the Ollama compatible agent backed by the official python client."""
from __future__ import annotations
from warnings import deprecated

import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
import ssl
from collections.abc import Mapping
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import certifi
import httpx
from ollama import AsyncClient, ChatResponse, ResponseError

from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation as conversation
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm

from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
@@ -23,119 +24,179 @@ from custom_components.llama_conversation.const import (
CONF_TOP_K,
CONF_TOP_P,
CONF_TYPICAL_P,
CONF_MIN_P,
CONF_ENABLE_THINK_MODE,
CONF_REQUEST_TIMEOUT,
CONF_OPENAI_API_KEY,
CONF_GENERIC_OPENAI_PATH,
CONF_OLLAMA_KEEP_ALIVE_MIN,
CONF_OLLAMA_JSON_MODE,
CONF_CONTEXT_LENGTH,
CONF_ENABLE_LEGACY_TOOL_CALLING,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DEFAULT_TYPICAL_P,
DEFAULT_MIN_P,
DEFAULT_ENABLE_THINK_MODE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_GENERIC_OPENAI_PATH,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_OLLAMA_JSON_MODE,
DEFAULT_CONTEXT_LENGTH,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
)

from custom_components.llama_conversation.entity import LocalLLMClient, TextGenerationResult

_LOGGER = logging.getLogger(__name__)

@deprecated("Use the built-in Ollama integration instead")

def _normalize_path(path: str | None) -> str:
if not path:
return ""
trimmed = str(path).strip("/")
return f"/{trimmed}" if trimmed else ""


def _build_default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
try:
context.load_verify_locations(certifi.where())
except OSError as err:
_LOGGER.debug("Failed to load certifi bundle for Ollama client: %s", err)
return context

class OllamaAPIClient(LocalLLMClient):
api_host: str
api_key: Optional[str]

def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
super().__init__(hass, client_options)
base_path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
self.api_host = format_url(
hostname=client_options[CONF_HOST],
port=client_options[CONF_PORT],
ssl=client_options[CONF_SSL],
path=client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
path=base_path,
)
self.api_key = client_options.get(CONF_OPENAI_API_KEY) or None
self._headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
self._ssl_context = _build_default_ssl_context() if client_options.get(CONF_SSL) else None

self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")
def _build_client(self, *, timeout: float | int | httpx.Timeout | None = None) -> AsyncClient:
timeout_config: httpx.Timeout | float | None = timeout
if isinstance(timeout, (int, float)):
timeout_config = httpx.Timeout(timeout)

return AsyncClient(
host=self.api_host,
headers=self._headers,
timeout=timeout_config,
verify=self._ssl_context,
)

@staticmethod
def get_name(client_options: dict[str, Any]):
host = client_options[CONF_HOST]
port = client_options[CONF_PORT]
ssl = client_options[CONF_SSL]
path = "/" + client_options[CONF_GENERIC_OPENAI_PATH]
path = _normalize_path(client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
return f"Ollama at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"
@staticmethod
async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
headers = {}
api_key = user_input.get(CONF_OPENAI_API_KEY)
api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
base_path = _normalize_path(user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH))
timeout_config: httpx.Timeout | float | None = httpx.Timeout(5)

verify_context = None
if user_input.get(CONF_SSL):
verify_context = await hass.async_add_executor_job(_build_default_ssl_context)

client = AsyncClient(
host=format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=base_path,
),
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
timeout=timeout_config,
verify=verify_context,
)

try:
session = async_get_clientsession(hass)
async with session.get(
format_url(
hostname=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
path=f"/{api_base_path}/api/tags"
),
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
if response.ok:
return None
else:
return f"HTTP Status {response.status}"
except Exception as ex:
return str(ex)

await client.list()
except httpx.TimeoutException:
return "Connection timed out"
except ResponseError as err:
return f"HTTP Status {err.status_code}: {err.error}"
except ConnectionError as err:
return str(err)

return None
async def async_get_available_models(self) -> List[str]:
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
client = self._build_client(timeout=5)
try:
response = await client.list()
except httpx.TimeoutException as err:
raise HomeAssistantError("Timed out while fetching models from the Ollama server") from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to fetch models from the Ollama server: {err}") from err

session = async_get_clientsession(self.hass)
async with session.get(
f"{self.api_host}/api/tags",
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
response.raise_for_status()
models_result = await response.json()
models: List[str] = []
for model in getattr(response, "models", []) or []:
candidate = getattr(model, "name", None) or getattr(model, "model", None)
if candidate:
models.append(candidate)

return [x["name"] for x in models_result["models"]]
return models
def _extract_response(self, response_json: Dict) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
# if response_json["prompt_eval_count"] + max_tokens > context_len:
# self._warn_context_size()
def _extract_response(self, response_chunk: ChatResponse) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
message = getattr(response_chunk, "message", None)
content = getattr(message, "content", None) if message else None
raw_tool_calls = getattr(message, "tool_calls", None) if message else None

if "response" in response_json:
response = response_json["response"]
tool_calls = None
stop_reason = None
if response_json["done"] not in ["true", True]:
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
else:
response = response_json["message"]["content"]
raw_tool_calls = response_json["message"].get("tool_calls")
tool_calls = [ llm.ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"]) for x in raw_tool_calls] if raw_tool_calls else None
stop_reason = response_json.get("done_reason")
tool_calls: Optional[List[llm.ToolInput]] = None
if raw_tool_calls:
parsed_tool_calls: list[llm.ToolInput] = []
for tool_call in raw_tool_calls:
function = getattr(tool_call, "function", None)
name = getattr(function, "name", None) if function else None
if not name:
continue

# _LOGGER.debug(f"{response=} {tool_calls=}")
arguments = getattr(function, "arguments", None) or {}
if isinstance(arguments, Mapping):
arguments_dict = dict(arguments)
else:
arguments_dict = {"raw": arguments}

return response, tool_calls
parsed_tool_calls.append(llm.ToolInput(tool_name=name, tool_args=arguments_dict))

def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
if parsed_tool_calls:
tool_calls = parsed_tool_calls

if content is None and not tool_calls:
return None, None

return content, tool_calls

@staticmethod
def _format_keep_alive(value: Any) -> Any:
as_text = str(value).strip()
return 0 if as_text in {"0", "0.0"} else f"{as_text}m"
def _generate_stream(
self,
conversation: List[conversation.Content],
llm_api: llm.APIInstance | None,
agent_id: str,
entity_options: Dict[str, Any],
) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options.get(CONF_CHAT_MODEL, "")
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -145,58 +206,47 @@ class OllamaAPIClient(LocalLLMClient):
typical_p = entity_options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = entity_options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
think_mode = entity_options.get(CONF_ENABLE_THINK_MODE, DEFAULT_ENABLE_THINK_MODE)
json_mode = entity_options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)

request_params = {
"model": model_name,
"stream": True,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
"options": {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
},
options = {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
"min_p": entity_options.get(CONF_MIN_P, DEFAULT_MIN_P),
}

if json_mode:
request_params["format"] = "json"

if llm_api:
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())

endpoint = "/api/chat"
request_params["messages"] = get_oai_formatted_messages(conversation, tool_args_to_str=False)

headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

session = async_get_clientsession(self.hass)
messages = get_oai_formatted_messages(conversation, tool_args_to_str=False)
tools = None
if llm_api and not legacy_tool_calling:
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
keep_alive_payload = self._format_keep_alive(keep_alive)
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
response = None
chunk = None
client = self._build_client(timeout=timeout)
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=aiohttp.ClientTimeout(total=timeout),
headers=headers
) as response:
response.raise_for_status()

while True:
chunk = await response.content.readline()
if not chunk:
break

yield self._extract_response(json.loads(chunk))
except asyncio.TimeoutError as err:
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
except aiohttp.ClientError as err:
stream = await client.chat(
model=model_name,
messages=messages,
tools=tools,
stream=True,
think=think_mode,
format="json" if json_mode else None,
options=options,
keep_alive=keep_alive_payload,
)

async for chunk in stream:
yield self._extract_response(chunk)
except httpx.TimeoutException as err:
raise HomeAssistantError(
"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
) from err
except (ResponseError, ConnectionError) as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err

return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
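As a usage note on the keep_alive handling above: the helper sends minutes as an Ollama duration string and passes 0 through as a bare integer, which asks the server to unload the model right after the request. A standalone sketch of that mapping, re-declared here for illustration rather than imported from the integration:

```python
# Standalone re-statement of the _format_keep_alive helper, for illustration only.
from typing import Any


def format_keep_alive(value: Any) -> Any:
    as_text = str(value).strip()
    return 0 if as_text in {"0", "0.0"} else f"{as_text}m"


assert format_keep_alive(30) == "30m"   # keep the model loaded for 30 minutes
assert format_keep_alive("0") == 0      # 0 tells Ollama to unload right after the request
assert format_keep_alive(0.0) == 0
```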
@@ -104,6 +104,8 @@ CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.1
CONF_REQUEST_TIMEOUT = "request_timeout"
DEFAULT_REQUEST_TIMEOUT = 90
CONF_ENABLE_THINK_MODE = "enable_think_mode"
DEFAULT_ENABLE_THINK_MODE = False
CONF_BACKEND_TYPE = "model_backend"
BACKEND_TYPE_LLAMA_HF_OLD = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING_OLD = "llama_cpp_existing"
@@ -185,7 +187,7 @@ DEFAULT_GENERIC_OPENAI_PATH = "v1"
CONF_GENERIC_OPENAI_VALIDATE_MODEL = "openai_validate_model"
DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL = True
CONF_CONTEXT_LENGTH = "context_length"
DEFAULT_CONTEXT_LENGTH = 2048
DEFAULT_CONTEXT_LENGTH = 8192
CONF_LLAMACPP_BATCH_SIZE = "batch_size"
DEFAULT_LLAMACPP_BATCH_SIZE = 512
CONF_LLAMACPP_THREAD_COUNT = "n_threads"
@@ -11,6 +11,7 @@
"iot_class": "local_polling",
"requirements": [
"huggingface-hub>=0.23.0",
"webcolors>=24.8.0"
"webcolors>=24.8.0",
"ollama>=0.5.1"
]
}
@@ -1,2 +1,3 @@
huggingface-hub>=0.23.0
webcolors>=24.8.0
ollama>=0.5.1
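For completeness, a minimal sketch of the model-listing call that `async_get_available_models` now wraps; the host URL is a placeholder, and the attribute fallback mirrors the `name`/`model` check in the diff because older client releases exposed a different field:

```python
import asyncio

from ollama import AsyncClient


async def list_models() -> list[str]:
    # Placeholder host; the integration builds this from CONF_HOST/CONF_PORT/CONF_SSL.
    client = AsyncClient(host="http://127.0.0.1:11434")
    response = await client.list()
    names: list[str] = []
    for entry in response.models:
        # Newer ollama clients expose .model on each entry; older ones used .name.
        candidate = getattr(entry, "model", None) or getattr(entry, "name", None)
        if candidate:
            names.append(candidate)
    return names


print(asyncio.run(list_models()))
```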