Ollama and generic OpenAI theoretically work

Alex O'Connell
2025-09-15 22:10:25 -04:00
parent 61d52ae4d1
commit 05e3ceff7b
4 changed files with 66 additions and 83 deletions


@@ -89,7 +89,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
response.raise_for_status()
if stream:
async for line_bytes in response.content:
chunk = line_bytes.decode("utf-8").strip()
chunk = line_bytes.decode("utf-8").strip().removeprefix("data: ")
if not chunk.strip():
break
yield self._extract_response(json.loads(chunk), llm_api, user_input)
else:
response_json = await response.json()
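For reference, OpenAI-compatible servers stream chat completions as server-sent events: each payload line is prefixed with "data: " and the stream ends with a "data: [DONE]" sentinel, which is why the decode step above now strips that prefix. A minimal standalone sketch of that parsing (helper name and sample payload are illustrative, not from this repo):

import json

def parse_sse_line(line_bytes: bytes) -> dict | None:
    """Decode one SSE line from an OpenAI-compatible stream into a chunk dict."""
    chunk = line_bytes.decode("utf-8").strip().removeprefix("data: ")
    # Blank keep-alive lines and the terminal sentinel carry no JSON payload.
    if not chunk or chunk == "[DONE]":
        return None
    return json.loads(chunk)

assert parse_sse_line(b'data: {"choices": [{"delta": {"content": "Hi"}}]}') == {"choices": [{"delta": {"content": "Hi"}}]}
assert parse_sse_line(b"data: [DONE]") is None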
@@ -119,13 +122,10 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
if "tool_calls" in choice["delta"]:
tool_calls = []
for call in choice["delta"]["tool_calls"]:
success, tool_args, to_say = parse_raw_tool_call(
call["function"]["arguments"],
call["function"]["name"],
call["id"],
llm_api, user_input)
tool_args, to_say = parse_raw_tool_call(
call["function"], llm_api)
if success and tool_args:
if tool_args:
tool_calls.append(tool_args)
if to_say:
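For context, a streamed OpenAI-style tool call arrives inside choice["delta"]["tool_calls"] roughly in the shape below, which is why the call site now passes the whole call["function"] dict into parse_raw_tool_call. The values are illustrative:

# Illustrative shape of one streamed delta carrying a tool call (values made up).
delta = {
    "tool_calls": [
        {
            "index": 0,
            "id": "call_0",
            "type": "function",
            "function": {
                "name": "HassTurnOn",
                # "arguments" is a JSON-encoded string, not a dict; some servers
                # also split it across several chunks.
                "arguments": '{"name": "kitchen light"}',
            },
        }
    ]
}

for call in delta["tool_calls"]:
    function_block = call["function"]  # handed to parse_raw_tool_call as-is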
@@ -139,6 +139,8 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
if choice["finish_reason"] == "length" or choice["finish_reason"] == "content_filter":
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
_LOGGER.debug("Model chunk '%s'", response_text)
return TextGenerationResult(
response=response_text,
stop_reason=choice["finish_reason"],
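The finish_reason strings checked above follow the OpenAI convention; for reference, a short summary of the usual values (the descriptions are a summary, not taken from this repo):

# Standard OpenAI-style finish_reason values.
FINISH_REASONS = {
    "stop": "the model emitted a stop token; the response is complete",
    "tool_calls": "the model stopped in order to call a tool",
    "length": "max_tokens was reached; the response is truncated",
    "content_filter": "the response was cut off by a content filter",
}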


@@ -361,7 +361,7 @@ class LlamaCppAgent(LocalLLMAgent):
refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
async def _async_generate_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
async def _async_parse_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
think_prefix = self.entry.options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = self.entry.options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
tool_prefix = self.entry.options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -377,7 +377,7 @@ class LlamaCppAgent(LocalLLMAgent):
in_thinking = False
in_tool_call = False
tool_content = ""
last_5_tokens = []
last_5_tokens = [] # FIXME: this still returns the first few tokens of the tool call if the prefix is split across chunks
while chunk := await self.hass.async_add_executor_job(next_token):
content = chunk["choices"][0]["delta"].get("content")
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
@@ -409,6 +409,7 @@ class LlamaCppAgent(LocalLLMAgent):
elif tool_prefix in potential_block and not in_tool_call:
in_tool_call = True
last_5_tokens.clear()
elif tool_suffix in potential_block and in_tool_call:
in_tool_call = False
_LOGGER.debug("Tool content: %s", tool_content)
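For context on the FIXME above: because tokens stream in one at a time, a marker such as "<tool_call>" can be split across chunks, so the parser keeps a small rolling window of recent tokens and scans their concatenation. A standalone sketch of that approach, assuming default-style markers; note it reproduces the limitation the FIXME describes, because fragments of the prefix are already yielded as text before the window completes it:

from collections import deque

TOOL_PREFIX = "<tool_call>"    # assumed markers; configurable in the integration
TOOL_SUFFIX = "</tool_call>"

def split_stream(tokens):
    """Yield ("text", token) for plain output and ("tool", block) for tool calls."""
    window = deque(maxlen=5)
    in_tool_call = False
    tool_content = ""
    for token in tokens:
        window.append(token)
        potential_block = "".join(window)
        if not in_tool_call and TOOL_PREFIX in potential_block:
            in_tool_call = True
            window.clear()
            continue
        if in_tool_call and TOOL_SUFFIX in potential_block:
            in_tool_call = False
            yield ("tool", tool_content)
            tool_content = ""
            window.clear()
            continue
        if in_tool_call:
            tool_content += token
        else:
            yield ("text", token)

chunks = ["Sure", "<tool_", "call>", '{"name": "x"}', "</tool_call>", "!"]
print(list(split_stream(chunks)))
# [('text', 'Sure'), ('text', '<tool_'), ('tool', '{"name": "x"}'), ('text', '!')]
# The leaked '<tool_' fragment is exactly what the FIXME refers to.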
@@ -467,7 +468,7 @@ class LlamaCppAgent(LocalLLMAgent):
_LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
return self._async_generate_completion(llm_api, self.llm.create_chat_completion(
return self._async_parse_completion(llm_api, self.llm.create_chat_completion(
messages,
tools=tools,
temperature=temperature,
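For context, llama-cpp-python's create_chat_completion with stream=True returns a synchronous iterator of OpenAI-style chunks, which is why the result is wrapped by _async_parse_completion and pulled through the executor. A rough sketch of what the consuming side sees (model path and prompt are illustrative):

from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # illustrative path

stream = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Turn on the kitchen light"}],
    temperature=0.7,
    stream=True,
)

for chunk in stream:  # each chunk mirrors an OpenAI streaming response
    delta = chunk["choices"][0]["delta"]
    if content := delta.get("content"):
        print(content, end="", flush=True)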


@@ -6,15 +6,16 @@ import aiohttp
import asyncio
import json
import logging
from typing import Optional, Tuple, Dict, List, Any
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -83,23 +84,6 @@ class OllamaAPIAgent(LocalLLMAgent):
elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
request_params = {}
endpoint = "/api/chat"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
return endpoint, request_params
def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
endpoint = "/api/generate"
request_params["prompt"] = self._format_prompt(conversation)
request_params["raw"] = True # ignore prompt template
return endpoint, request_params
def _extract_response(self, response_json: Dict) -> TextGenerationResult:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
@@ -121,52 +105,11 @@ class OllamaAPIAgent(LocalLLMAgent):
return TextGenerationResult(
response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
)
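For reference, when streaming, Ollama emits one JSON object per line; _extract_response consumes objects roughly like the ones below (field names follow Ollama's documented API, values are illustrative). The TODO above refers to prompt_eval_count, which can come back smaller than the real prompt when Ollama reuses a cached prefix:

# Intermediate chunk: only the newly generated text.
intermediate_chunk = {
    "model": "llama3",
    "message": {"role": "assistant", "content": "Turning"},
    "done": False,
}

# Final chunk: empty content plus generation statistics.
final_chunk = {
    "model": "llama3",
    "message": {"role": "assistant", "content": ""},
    "done": True,
    "done_reason": "stop",
    "prompt_eval_count": 26,
    "eval_count": 12,
}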
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
request_params = {
"model": self.model_name,
"stream": True,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
"options": {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_preDict": max_tokens,
}
}
if json_mode:
request_params["format"] = "json"
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
request_params.update(additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
async def _async_generate_with_parameters(self, endpoint: str, request_params: dict[str, Any], headers: dict[str, Any], timeout: int) -> AsyncGenerator[TextGenerationResult, None]:
session = async_get_clientsession(self.hass)
response = None
result = TextGenerationResult(
response="", response_streamed=True
)
try:
async with session.post(
f"{self.api_host}{endpoint}",
@@ -181,18 +124,51 @@ class OllamaAPIAgent(LocalLLMAgent):
if not chunk:
break
parsed_chunk = self._extract_response(json.loads(chunk))
result.response += parsed_chunk.response
result.stop_reason = parsed_chunk.stop_reason
result.tool_calls = parsed_chunk.tool_calls
yield self._extract_response(json.loads(chunk))
except asyncio.TimeoutError:
return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
_LOGGER.debug(result)
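For context, Ollama streams its responses as newline-delimited JSON over a chunked HTTP response; iterating response.content yields one line per object, which is what the generator above decodes and re-yields as TextGenerationResult objects. A standalone aiohttp sketch of that pattern (URL and payload are illustrative):

import asyncio
import json

import aiohttp

async def stream_ndjson(url: str, payload: dict):
    """Yield parsed JSON objects from a streaming (NDJSON) endpoint."""
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            async for line in response.content:  # one JSON object per line
                line = line.strip()
                if not line:
                    continue
                yield json.loads(line)

async def main():
    payload = {
        "model": "llama3",
        "stream": True,
        "messages": [{"role": "user", "content": "hi"}],
    }
    async for chunk in stream_ndjson("http://localhost:11434/api/chat", payload):
        print(chunk.get("message", {}).get("content", ""), end="")

# asyncio.run(main())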
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
return result
request_params = {
"model": self.model_name,
"stream": True,
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
"options": {
"num_ctx": context_length,
"top_p": top_p,
"top_k": top_k,
"typical_p": typical_p,
"temperature": temperature,
"num_predict": max_tokens,
},
}
if json_mode:
request_params["format"] = "json"
if llm_api:
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
endpoint = "/api/chat"
request_params["messages"] = get_oai_formatted_messages(conversation)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return self._async_generate_with_parameters(endpoint, request_params, headers, timeout)
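Putting the pieces together, the body posted to {api_host}/api/chat ends up shaped roughly like this; the model name, option values, and the single tool are illustrative, but every key comes from the code above or from get_oai_formatted_tools:

request_body = {
    "model": "llama3",
    "stream": True,
    "keep_alive": "30m",  # minutes, from CONF_OLLAMA_KEEP_ALIVE_MIN
    "options": {
        "num_ctx": 2048,
        "top_p": 1.0,
        "top_k": 40,
        "typical_p": 1.0,
        "temperature": 0.1,
        "num_predict": 128,
    },
    # "format": "json" is added only when JSON mode is enabled.
    "messages": [
        {"role": "system", "content": "You are a Home Assistant voice assistant..."},
        {"role": "user", "content": "Turn on the kitchen light"},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "HassTurnOn",
                "description": "Call the Home Assistant service 'HassTurnOn'",
                "parameters": {"type": "object", "properties": {"name": {"type": "string"}}},
            },
        }
    ],
}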


@@ -263,6 +263,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
"type": "function",
"function": {
"name": tool["name"],
"description": f"Call the Home Assistant service '{tool['name']}'",
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ]
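The parameters field above is produced by converting each tool's voluptuous argument schema into JSON Schema; in Home Assistant this conversion is typically voluptuous_openapi.convert, roughly as sketched here with an illustrative schema:

import voluptuous as vol
from voluptuous_openapi import convert

# Illustrative argument schema, similar to what an llm_api tool exposes.
arguments = vol.Schema({
    vol.Required("name"): str,
    vol.Optional("brightness"): int,
})

parameters = convert(arguments)
# Roughly: {"type": "object",
#           "properties": {"name": {"type": "string"}, "brightness": {"type": "integer"}},
#           "required": ["name"]}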
@@ -360,8 +361,11 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
return tools
def parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
parsed_tool_call: dict = json.loads(raw_block)
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:
parsed_tool_call: dict = json.loads(raw_block)
if llm_api.api.id == HOME_LLM_API_ID:
schema_to_validate = vol.Schema({
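A quick, self-contained illustration of the widened first argument: the OpenAI-compatible path now passes the already-parsed "function" dict straight through, while the llama.cpp path still passes the raw JSON text cut out between the tool-call markers (helper name here is illustrative):

import json

def normalize_tool_block(raw_block) -> dict:
    """Mirror of the first step above: accept either a dict or a JSON string."""
    if isinstance(raw_block, dict):
        return raw_block
    return json.loads(raw_block)

assert normalize_tool_block({"name": "HassTurnOn"}) == {"name": "HassTurnOn"}
assert normalize_tool_block('{"name": "HassTurnOn"}') == {"name": "HassTurnOn"}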