"""Defines the ollama compatible agent"""
from __future__ import annotations

import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any

from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_TYPICAL_P,
    CONF_REQUEST_TIMEOUT,
    CONF_OPENAI_API_KEY,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
)

from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult

_LOGGER = logging.getLogger(__name__)

class OllamaAPIAgent(LocalLLMAgent):
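    """Conversation agent backed by a remote Ollama (or Ollama-compatible) server.

    Requests are sent to either the /api/chat or /api/generate endpoint of the
    configured server, depending on the config entry options.
    """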
    api_host: str
    api_key: Optional[str]
    model_name: Optional[str]

    async def _async_load_model(self, entry: ConfigEntry) -> None:
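        """Verify connectivity and that the configured model exists on the server.

        Ollama loads models on demand, so there is nothing to load locally here.
        """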
        self.api_host = format_url(
            hostname=entry.data[CONF_HOST],
            port=entry.data[CONF_PORT],
            ssl=entry.data[CONF_SSL],
            path=""
        )
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)

        # ollama handles loading for us so just make sure the model is available
        try:
            headers = {}
            session = async_get_clientsession(self.hass)
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"

            async with session.get(
                f"{self.api_host}/api/tags",
                headers=headers,
            ) as response:
                response.raise_for_status()
                currently_downloaded_result = await response.json()

        except Exception as ex:
            _LOGGER.debug("Connection error was: %s", repr(ex))
            raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
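
        # /api/tags lists the models available on the server. The relevant shape is
        # roughly (illustrative, not a real response):
        #   {"models": [{"name": "llama3:8b", ...}, {"name": "home-llm:latest", ...}]}
        # Names are "model:tag" strings, so the configured model is accepted with or
        # without an explicit tag.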
        model_names = [ x["name"] for x in currently_downloaded_result["models"] ]
        if ":" in self.model_name:
            if not any([ name == self.model_name for name in model_names]):
                raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
        elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
            raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
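        """Build the endpoint and request parameters for Ollama's /api/chat endpoint.

        The internal conversation turns store their text under "message"; Ollama's
        chat API expects it under "content", so each turn is re-mapped here.
        """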
        request_params = {}

        endpoint = "/api/chat"
        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

        return endpoint, request_params

    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
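        """Build the endpoint and request parameters for Ollama's /api/generate endpoint.

        The prompt is already fully formatted by the integration, so "raw" is set to
        bypass the prompt template baked into the Ollama model.
        """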
        request_params = {}

        endpoint = "/api/generate"
        request_params["prompt"] = self._format_prompt(conversation)
        request_params["raw"] = True # ignore prompt template

        return endpoint, request_params

    def _extract_response(self, response_json: Dict) -> TextGenerationResult:
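        """Convert one (possibly streamed) Ollama response chunk into a TextGenerationResult.

        Both endpoint formats are handled. Illustrative chunk shapes (not real output):
            /api/generate: {"response": "Turning on the light.", "done": true}
            /api/chat:     {"message": {"role": "assistant", "content": "..."}, "done_reason": "stop"}
        """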
        # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
        # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        # max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        # if response_json["prompt_eval_count"] + max_tokens > context_len:
        #     self._warn_context_size()

        if "response" in response_json:
            response = response_json["response"]
            tool_calls = None
            stop_reason = None
            if response_json["done"] not in ["true", True]:
                _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
        else:
            response = response_json["message"]["content"]
            tool_calls = response_json["message"].get("tool_calls")
            stop_reason = response_json.get("done_reason")

        return TextGenerationResult(
            response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
        )

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
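        """Stream a completion from the Ollama server and accumulate it into a single result."""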
        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
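
        # Assemble the request body. The "options" block maps the integration's sampler
        # settings onto Ollama's generation options: "num_predict" caps the number of
        # generated tokens and "num_ctx" sets the context window size.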
        request_params = {
            "model": self.model_name,
            "stream": True,
            "keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
            "options": {
                "num_ctx": context_length,
                "top_p": top_p,
                "top_k": top_k,
                "typical_p": typical_p,
                "temperature": temperature,
                "num_predict": max_tokens,
            }
        }

        if json_mode:
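            # Ollama's JSON mode constrains the model output to be syntactically valid JSON.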
            request_params["format"] = "json"

        if use_chat_api:
            endpoint, additional_params = self._chat_completion_params(conversation)
        else:
            endpoint, additional_params = self._completion_params(conversation)

        request_params.update(additional_params)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        session = async_get_clientsession(self.hass)
        response = None
        result = TextGenerationResult(
            response="", response_streamed=True
        )
        try:
            async with session.post(
                f"{self.api_host}{endpoint}",
                json=request_params,
                timeout=timeout,
                headers=headers
            ) as response:
                response.raise_for_status()
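
                # Ollama streams newline-delimited JSON: each line is one JSON object, and
                # the final object carries the "done" / "done_reason" fields. Read lines
                # until the stream ends and accumulate the pieces into a single result.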
                while True:
                    chunk = await response.content.readline()
                    if not chunk:
                        break

                    parsed_chunk = self._extract_response(json.loads(chunk))
                    result.response += parsed_chunk.response
                    result.stop_reason = parsed_chunk.stop_reason
                    result.tool_calls = parsed_chunk.tool_calls
        except asyncio.TimeoutError:
            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {response}")
            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

        _LOGGER.debug(result)

        return result