diff --git a/custom_components/llama_conversation/backends/generic_openai.py b/custom_components/llama_conversation/backends/generic_openai.py
index 45c326a..f76ff15 100644
--- a/custom_components/llama_conversation/backends/generic_openai.py
+++ b/custom_components/llama_conversation/backends/generic_openai.py
@@ -89,7 +89,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
             response.raise_for_status()
 
             if stream:
                 async for line_bytes in response.content:
-                    chunk = line_bytes.decode("utf-8").strip()
+                    chunk = line_bytes.decode("utf-8").strip().removeprefix("data: ")
+                    if not chunk.strip():
+                        break
+                    yield self._extract_response(json.loads(chunk), llm_api, user_input)
             else:
                 response_json = await response.json()
@@ -119,13 +122,10 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
         if "tool_calls" in choice["delta"]:
             tool_calls = []
             for call in choice["delta"]["tool_calls"]:
-                success, tool_args, to_say = parse_raw_tool_call(
-                    call["function"]["arguments"],
-                    call["function"]["name"],
-                    call["id"],
-                    llm_api, user_input)
+                tool_args, to_say = parse_raw_tool_call(
+                    call["function"], llm_api)
 
-                if success and tool_args:
+                if tool_args:
                     tool_calls.append(tool_args)
 
                 if to_say:
@@ -139,6 +139,8 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
         if choice["finish_reason"] == "length" or choice["finish_reason"] == "content_filter":
             _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
 
+        _LOGGER.debug("Model chunk '%s'", response_text)
+
         return TextGenerationResult(
             response=response_text,
             stop_reason=choice["finish_reason"],
diff --git a/custom_components/llama_conversation/backends/llamacpp.py b/custom_components/llama_conversation/backends/llamacpp.py
index 6e554a3..8189d23 100644
--- a/custom_components/llama_conversation/backends/llamacpp.py
+++ b/custom_components/llama_conversation/backends/llamacpp.py
@@ -361,7 +361,7 @@ class LlamaCppAgent(LocalLLMAgent):
             refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
             async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
 
-    async def _async_generate_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
+    async def _async_parse_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
         think_prefix = self.entry.options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
         think_suffix = self.entry.options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
         tool_prefix = self.entry.options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -377,7 +377,7 @@ class LlamaCppAgent(LocalLLMAgent):
         in_thinking = False
         in_tool_call = False
         tool_content = ""
-        last_5_tokens = []
+        last_5_tokens = [] # FIXME: this still returns the first few tokens of the tool call if the prefix is split across chunks
         while chunk := await self.hass.async_add_executor_job(next_token):
             content = chunk["choices"][0]["delta"].get("content")
             tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
@@ -409,6 +409,7 @@ class LlamaCppAgent(LocalLLMAgent):
             elif tool_prefix in potential_block and not in_tool_call:
                 in_tool_call = True
                 last_5_tokens.clear()
+
             elif tool_suffix in potential_block and in_tool_call:
                 in_tool_call = False
                 _LOGGER.debug("Tool content: %s", tool_content)
@@ -467,7 +468,7 @@ class LlamaCppAgent(LocalLLMAgent):
 
         _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
 
-        return self._async_generate_completion(llm_api, self.llm.create_chat_completion(
+        return self._async_parse_completion(llm_api, self.llm.create_chat_completion(
             messages,
             tools=tools,
             temperature=temperature,
diff --git a/custom_components/llama_conversation/backends/ollama.py b/custom_components/llama_conversation/backends/ollama.py
index f76615a..2d79dd2 100644
--- a/custom_components/llama_conversation/backends/ollama.py
+++ b/custom_components/llama_conversation/backends/ollama.py
@@ -6,15 +6,16 @@ import aiohttp
 import asyncio
 import json
 import logging
-from typing import Optional, Tuple, Dict, List, Any
+from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator
 
 from homeassistant.components import conversation as conversation
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
 from homeassistant.exceptions import ConfigEntryNotReady
 from homeassistant.helpers.aiohttp_client import async_get_clientsession
+from homeassistant.helpers import llm
 
-from custom_components.llama_conversation.utils import format_url
+from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -83,23 +84,6 @@ class OllamaAPIAgent(LocalLLMAgent):
         elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
             raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
 
-    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict]:
-        request_params = {}
-
-        endpoint = "/api/chat"
-        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
-
-        return endpoint, request_params
-
-    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
-        request_params = {}
-
-        endpoint = "/api/generate"
-        request_params["prompt"] = self._format_prompt(conversation)
-        request_params["raw"] = True # ignore prompt template
-
-        return endpoint, request_params
-
     def _extract_response(self, response_json: Dict) -> TextGenerationResult:
         # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
         # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
@@ -121,52 +105,11 @@ class OllamaAPIAgent(LocalLLMAgent):
         return TextGenerationResult(
             response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
         )
-
-    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
-        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
-        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
-        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
-        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
-        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
-        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
-        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
-
-        request_params = {
-            "model": self.model_name,
-            "stream": True,
-            "keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
-            "options": {
-                "num_ctx": context_length,
-                "top_p": top_p,
-                "top_k": top_k,
-                "typical_p": typical_p,
-                "temperature": temperature,
-                "num_preDict": max_tokens,
-            }
-        }
-
-        if json_mode:
-            request_params["format"] = "json"
-
-        if use_chat_api:
-            endpoint, additional_params = self._chat_completion_params(conversation)
-        else:
-            endpoint, additional_params = self._completion_params(conversation)
-
-        request_params.update(additional_params)
-
-        headers = {}
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
+
+    async def _async_generate_with_parameters(self, endpoint: str, request_params: dict[str, Any], headers: dict[str, Any], timeout: int) -> AsyncGenerator[TextGenerationResult, None]:
         session = async_get_clientsession(self.hass)
         response = None
-        result = TextGenerationResult(
-            response="", response_streamed=True
-        )
+
         try:
             async with session.post(
                 f"{self.api_host}{endpoint}",
@@ -181,18 +124,51 @@ class OllamaAPIAgent(LocalLLMAgent):
                     if not chunk:
                         break
 
-                    parsed_chunk = self._extract_response(json.loads(chunk))
-                    result.response += parsed_chunk.response
-                    result.stop_reason = parsed_chunk.stop_reason
-                    result.tool_calls = parsed_chunk.tool_calls
+                    yield self._extract_response(json.loads(chunk))
         except asyncio.TimeoutError:
-            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
+            yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
         except aiohttp.ClientError as err:
             _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
             _LOGGER.debug(f"Result was: {response}")
-            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
+            yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
 
-        _LOGGER.debug(result)
+    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
+        context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
+        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
+        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
+        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
+        top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
+        typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
+        timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
+        keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
+        json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
 
-        return result
+        request_params = {
+            "model": self.model_name,
+            "stream": True,
+            "keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
+            "options": {
+                "num_ctx": context_length,
+                "top_p": top_p,
+                "top_k": top_k,
+                "typical_p": typical_p,
+                "temperature": temperature,
+                "num_predict": max_tokens,
+            },
+        }
+
+        if json_mode:
+            request_params["format"] = "json"
+
+        if llm_api:
+            request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
+
+        endpoint = "/api/chat"
+        request_params["messages"] = get_oai_formatted_messages(conversation)
+
+        headers = {}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        return self._async_generate_with_parameters(endpoint, request_params, headers, timeout)
diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py
index 5006118..4cfbfbd 100644
--- a/custom_components/llama_conversation/utils.py
+++ b/custom_components/llama_conversation/utils.py
@@ -263,6 +263,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
         "type": "function",
         "function": {
             "name": tool["name"],
+            "description": f"Call the Home Assistant service '{tool['name']}'",
             "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
         }
     } for tool in get_home_llm_tools(llm_api, domains) ]
@@ -360,8 +361,11 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
     return tools
 
 
-def parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
-    parsed_tool_call: dict = json.loads(raw_block)
+def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
+    if isinstance(raw_block, dict):
+        parsed_tool_call = raw_block
+    else:
+        parsed_tool_call: dict = json.loads(raw_block)
 
     if llm_api.api.id == HOME_LLM_API_ID:
         schema_to_validate = vol.Schema({