Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)

Commit: ollama and generic openai theoretically work
@@ -89,7 +89,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
             response.raise_for_status()
             if stream:
                 async for line_bytes in response.content:
-                    chunk = line_bytes.decode("utf-8").strip()
+                    chunk = line_bytes.decode("utf-8").strip().removeprefix("data: ")
+                    if not chunk.strip():
+                        break
+
                     yield self._extract_response(json.loads(chunk), llm_api, user_input)
             else:
                 response_json = await response.json()
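For context on the hunk above: OpenAI-compatible servers stream responses as server-sent events, where each line is prefixed with "data: " and the stream ends with blank keep-alive lines or a "data: [DONE]" sentinel. A minimal standalone sketch of that parsing pattern (the URL, payload, and skip-vs-break handling are illustrative, not the integration's exact control flow):

import aiohttp
import json

async def stream_chat_chunks(url: str, payload: dict):
    """Yield parsed JSON chunks from an OpenAI-compatible streaming endpoint."""
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            async for line_bytes in response.content:
                # Each SSE line looks like: data: {"choices": [...]}
                chunk = line_bytes.decode("utf-8").strip().removeprefix("data: ")
                if not chunk or chunk == "[DONE]":
                    continue  # skip keep-alive blanks and the end-of-stream sentinel
                yield json.loads(chunk)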
@@ -119,13 +122,10 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
         if "tool_calls" in choice["delta"]:
             tool_calls = []
             for call in choice["delta"]["tool_calls"]:
-                success, tool_args, to_say = parse_raw_tool_call(
-                    call["function"]["arguments"],
-                    call["function"]["name"],
-                    call["id"],
-                    llm_api, user_input)
+                tool_args, to_say = parse_raw_tool_call(
+                    call["function"], llm_api)

-                if success and tool_args:
+                if tool_args:
                     tool_calls.append(tool_args)

                 if to_say:
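The hunk above changes what gets handed to parse_raw_tool_call: the whole "function" object from the streamed delta instead of its individual fields. For reference, the shape of an OpenAI-style streaming delta that carries a tool call (all values are illustrative):

delta = {
    "tool_calls": [
        {
            "index": 0,
            "id": "call_0",
            "type": "function",
            "function": {
                "name": "HassTurnOn",                        # illustrative tool name
                "arguments": "{\"name\": \"kitchen light\"}"  # JSON-encoded string
            },
        }
    ]
}
# New call site: parse_raw_tool_call(call["function"], llm_api)
# Old call site: parse_raw_tool_call(call["function"]["arguments"], call["function"]["name"], call["id"], llm_api, user_input)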
@@ -139,6 +139,8 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
         if choice["finish_reason"] == "length" or choice["finish_reason"] == "content_filter":
             _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

+        _LOGGER.debug("Model chunk '%s'", response_text)
+
         return TextGenerationResult(
             response=response_text,
             stop_reason=choice["finish_reason"],
@@ -361,7 +361,7 @@ class LlamaCppAgent(LocalLLMAgent):
             refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
             async_call_later(self.hass, float(refresh_delay), refresh_if_requested)

-    async def _async_generate_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
+    async def _async_parse_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
         think_prefix = self.entry.options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
         think_suffix = self.entry.options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
         tool_prefix = self.entry.options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -377,7 +377,7 @@ class LlamaCppAgent(LocalLLMAgent):
         in_thinking = False
         in_tool_call = False
         tool_content = ""
-        last_5_tokens = []
+        last_5_tokens = [] # FIXME: this still returns the first few tokens of the tool call if the prefix is split across chunks
         while chunk := await self.hass.async_add_executor_job(next_token):
             content = chunk["choices"][0]["delta"].get("content")
             tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
@@ -409,6 +409,7 @@ class LlamaCppAgent(LocalLLMAgent):
             elif tool_prefix in potential_block and not in_tool_call:
                 in_tool_call = True
+                last_5_tokens.clear()

             elif tool_suffix in potential_block and in_tool_call:
                 in_tool_call = False
                 _LOGGER.debug("Tool content: %s", tool_content)
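The FIXME above notes that a tool-call prefix split across streamed chunks can still leak its first tokens into the response. One general way to handle that (a sketch of the buffering technique only, not the integration's code) is to hold back any trailing text that could still be the start of the marker:

def make_marker_splitter(marker: str):
    """Feed streamed text chunks; returns (text_safe_to_emit, marker_seen)."""
    buffer = ""

    def feed(chunk: str) -> tuple[str, bool]:
        nonlocal buffer
        buffer += chunk
        idx = buffer.find(marker)
        if idx != -1:
            emit = buffer[:idx]
            buffer = buffer[idx + len(marker):]
            return emit, True
        # Keep the longest buffer suffix that is still a prefix of the marker.
        keep = 0
        for i in range(1, len(marker)):
            if buffer.endswith(marker[:i]):
                keep = i
        if keep:
            emit, buffer = buffer[:-keep], buffer[-keep:]
        else:
            emit, buffer = buffer, ""
        return emit, False

    return feed

# Example: the prefix "<tool_call>" arriving split as "<tool_" then "call>{...}"
feed = make_marker_splitter("<tool_call>")
print(feed("Sure, I can do that <tool_"))   # ('Sure, I can do that ', False)
print(feed("call>{\"name\": ...}"))          # ('', True)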
@@ -467,7 +468,7 @@ class LlamaCppAgent(LocalLLMAgent):

         _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")

-        return self._async_generate_completion(llm_api, self.llm.create_chat_completion(
+        return self._async_parse_completion(llm_api, self.llm.create_chat_completion(
             messages,
             tools=tools,
             temperature=temperature,
@@ -6,15 +6,16 @@ import aiohttp
 import asyncio
 import json
 import logging
-from typing import Optional, Tuple, Dict, List, Any
+from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator

+from homeassistant.components import conversation as conversation
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
 from homeassistant.exceptions import ConfigEntryNotReady
 from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers import llm

-from custom_components.llama_conversation.utils import format_url
+from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -83,23 +84,6 @@ class OllamaAPIAgent(LocalLLMAgent):
         elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
             raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")

-    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
-        request_params = {}
-
-        endpoint = "/api/chat"
-        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
-
-        return endpoint, request_params
-
-    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
-        request_params = {}
-
-        endpoint = "/api/generate"
-        request_params["prompt"] = self._format_prompt(conversation)
-        request_params["raw"] = True # ignore prompt template
-
-        return endpoint, request_params
-
     def _extract_response(self, response_json: Dict) -> TextGenerationResult:
         # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
         # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
@@ -121,52 +105,11 @@ class OllamaAPIAgent(LocalLLMAgent):
         return TextGenerationResult(
             response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
         )

-    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
-        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
-        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
-        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
-        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
-        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
-        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
-        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
-
-        request_params = {
-            "model": self.model_name,
-            "stream": True,
-            "keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
-            "options": {
-                "num_ctx": context_length,
-                "top_p": top_p,
-                "top_k": top_k,
-                "typical_p": typical_p,
-                "temperature": temperature,
-                "num_predict": max_tokens,
-            }
-        }
-
-        if json_mode:
-            request_params["format"] = "json"
-
-        if use_chat_api:
-            endpoint, additional_params = self._chat_completion_params(conversation)
-        else:
-            endpoint, additional_params = self._completion_params(conversation)
-
-        request_params.update(additional_params)
-
-        headers = {}
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
-
+    async def _async_generate_with_parameters(self, endpoint: str, request_params: dict[str, Any], headers: dict[str, Any], timeout: int) -> AsyncGenerator[TextGenerationResult, None]:
         session = async_get_clientsession(self.hass)
         response = None
-        result = TextGenerationResult(
-            response="", response_streamed=True
-        )

         try:
             async with session.post(
                 f"{self.api_host}{endpoint}",
@@ -181,18 +124,51 @@ class OllamaAPIAgent(LocalLLMAgent):
                     if not chunk:
                         break

-                    parsed_chunk = self._extract_response(json.loads(chunk))
-                    result.response += parsed_chunk.response
-                    result.stop_reason = parsed_chunk.stop_reason
-                    result.tool_calls = parsed_chunk.tool_calls
+                    yield self._extract_response(json.loads(chunk))
         except asyncio.TimeoutError:
-            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
+            yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
         except aiohttp.ClientError as err:
             _LOGGER.debug(f"Err was: {err}")
             _LOGGER.debug(f"Request was: {request_params}")
             _LOGGER.debug(f"Result was: {response}")
-            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
+            yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

-        _LOGGER.debug(result)
+    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
+        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
+        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
+        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
+        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
+        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
+        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
+        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
+        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
+        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)

-        return result
+        request_params = {
+            "model": self.model_name,
+            "stream": True,
+            "keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
+            "options": {
+                "num_ctx": context_length,
+                "top_p": top_p,
+                "top_k": top_k,
+                "typical_p": typical_p,
+                "temperature": temperature,
+                "num_predict": max_tokens,
+            },
+        }
+
+        if json_mode:
+            request_params["format"] = "json"
+
+        if llm_api:
+            request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
+
+        endpoint = "/api/chat"
+        request_params["messages"] = get_oai_formatted_messages(conversation)
+
+        headers = {}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        return self._async_generate_with_parameters(endpoint, request_params, headers, timeout)
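For context on the rewritten Ollama path: /api/chat with "stream": true returns newline-delimited JSON, one object per line, with partial text under message.content and a done flag on the final object, which is why the loop above can JSON-decode each line independently. A minimal standalone sketch (the base URL, model name, and option values are illustrative):

import aiohttp
import json

async def ollama_chat_stream(base_url: str, model: str, messages: list[dict]):
    """Yield message deltas from Ollama's streaming /api/chat endpoint."""
    payload = {
        "model": model,
        "messages": messages,
        "stream": True,
        "options": {"num_ctx": 4096, "num_predict": 256},
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/api/chat", json=payload) as response:
            response.raise_for_status()
            async for line in response.content:
                if not line.strip():
                    continue
                data = json.loads(line)
                # Each line carries a partial assistant message; "done" marks the end.
                yield data.get("message", {}).get("content", "")
                if data.get("done"):
                    break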
@@ -263,6 +263,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
         "type": "function",
         "function": {
             "name": tool["name"],
+            "description": f"Call the Home Assistant service '{tool['name']}'",
             "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
         }
     } for tool in get_home_llm_tools(llm_api, domains) ]
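The hunk above adds a description field to each converted tool. The resulting entries follow the standard OpenAI function-tool schema; an illustrative example of one formatted tool (the name and parameters are made up, not the integration's actual output):

example_tool = {
    "type": "function",
    "function": {
        "name": "HassTurnOn",
        "description": "Call the Home Assistant service 'HassTurnOn'",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}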
@@ -360,8 +361,11 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic

     return tools

-def parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
-    parsed_tool_call: dict = json.loads(raw_block)
+def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
+    if isinstance(raw_block, dict):
+        parsed_tool_call = raw_block
+    else:
+        parsed_tool_call: dict = json.loads(raw_block)

     if llm_api.api.id == HOME_LLM_API_ID:
         schema_to_validate = vol.Schema({
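A minimal standalone sketch of the str-or-dict normalization this last hunk introduces (validation against the Home LLM schema is omitted; the helper name and tool values are illustrative):

import json
from typing import Any

def normalize_tool_call(raw_block: str | dict) -> dict[str, Any]:
    """Accept either a pre-parsed tool-call dict (OpenAI-style streaming deltas)
    or a raw JSON string scraped from model output, and return a dict."""
    if isinstance(raw_block, dict):
        return raw_block
    return json.loads(raw_block)

# Both call sites can now funnel through the same helper:
assert normalize_tool_call('{"name": "HassTurnOn", "arguments": {}}') == \
       normalize_tool_call({"name": "HassTurnOn", "arguments": {}})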