# home-llm/custom_components/llama_conversation/backends/ollama.py
"""Defines the ollama compatible agent"""
from __future__ import annotations
import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any
from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_TYPICAL_P,
    CONF_REQUEST_TIMEOUT,
    CONF_OPENAI_API_KEY,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
_LOGGER = logging.getLogger(__name__)


class OllamaAPIAgent(LocalLLMAgent):
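    """Agent backend that talks to a remote Ollama server over its HTTP API."""
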
    api_host: str
    api_key: Optional[str]
    model_name: Optional[str]

    async def _async_load_model(self, entry: ConfigEntry) -> None:
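        """Resolve the server URL from the config entry and verify the configured model is available."""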
        self.api_host = format_url(
            hostname=entry.data[CONF_HOST],
            port=entry.data[CONF_PORT],
            ssl=entry.data[CONF_SSL],
            path=""
        )
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)

        # ollama handles loading for us so just make sure the model is available
        try:
            headers = {}
            session = async_get_clientsession(self.hass)
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"

            async with session.get(
                f"{self.api_host}/api/tags",
                headers=headers,
            ) as response:
                response.raise_for_status()
                currently_downloaded_result = await response.json()
        except Exception as ex:
            _LOGGER.debug("Connection error was: %s", repr(ex))
            raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex

        model_names = [x["name"] for x in currently_downloaded_result["models"]]
        if ":" in self.model_name:
            if not any([name == self.model_name for name in model_names]):
                raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
        elif not any([name.split(":")[0] == self.model_name for name in model_names]):
            raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
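        """Build the endpoint and request parameters for Ollama's /api/chat endpoint."""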
        request_params = {}
        endpoint = "/api/chat"

        request_params["messages"] = [{"role": x["role"], "content": x["message"]} for x in conversation]

        return endpoint, request_params

    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
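        """Build the endpoint and request parameters for Ollama's raw /api/generate endpoint."""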
        request_params = {}
        endpoint = "/api/generate"

        request_params["prompt"] = self._format_prompt(conversation)
        request_params["raw"] = True  # ignore prompt template

        return endpoint, request_params

    def _extract_response(self, response_json: Dict) -> TextGenerationResult:
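        """Convert a single JSON chunk from the Ollama response stream into a TextGenerationResult."""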
        # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
        # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        # max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        # if response_json["prompt_eval_count"] + max_tokens > context_len:
        #     self._warn_context_size()

        if "response" in response_json:
            # /api/generate chunks carry the text directly and have no tool calls
            response = response_json["response"]
            tool_calls = None
            stop_reason = None
            if response_json["done"] not in ["true", True]:
                _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
        else:
            # /api/chat chunks wrap the text in a message object
            response = response_json["message"]["content"]
            tool_calls = response_json["message"].get("tool_calls")
            stop_reason = response_json.get("done_reason")

        return TextGenerationResult(
            response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
        )

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
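        """Send the conversation to the Ollama server and stream back the generated result."""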
        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)

        request_params = {
            "model": self.model_name,
            "stream": True,
            "keep_alive": f"{keep_alive}m",  # prevent ollama from unloading the model
            "options": {
                "num_ctx": context_length,
                "top_p": top_p,
                "top_k": top_k,
                "typical_p": typical_p,
                "temperature": temperature,
                "num_predict": max_tokens,
            }
        }

        if json_mode:
            request_params["format"] = "json"

        if use_chat_api:
            endpoint, additional_params = self._chat_completion_params(conversation)
        else:
            endpoint, additional_params = self._completion_params(conversation)

        request_params.update(additional_params)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        session = async_get_clientsession(self.hass)
        response = None
        result = TextGenerationResult(
            response="", response_streamed=True
        )
        try:
            async with session.post(
                f"{self.api_host}{endpoint}",
                json=request_params,
                timeout=timeout,
                headers=headers
            ) as response:
                response.raise_for_status()

                # the response is streamed as newline-delimited JSON; read and merge one chunk at a time
                while True:
                    chunk = await response.content.readline()
                    if not chunk:
                        break

                    parsed_chunk = self._extract_response(json.loads(chunk))
                    result.response += parsed_chunk.response
                    result.stop_reason = parsed_chunk.stop_reason
                    result.tool_calls = parsed_chunk.tool_calls
        except asyncio.TimeoutError:
            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug("Err was: %s", err)
            _LOGGER.debug("Request was: %s", request_params)
            _LOGGER.debug("Result was: %s", response)
            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

        _LOGGER.debug(result)
        return result