diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index c42a034..95751bd 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -62,7 +62,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> agent_cls = LlamaCppPythonAPIAgent elif backend_type == BACKEND_TYPE_OLLAMA: agent_cls = OllamaAPIAgent - + return agent_cls(hass, entry) # create the agent in an executor job because the constructor calls `open()` @@ -74,7 +74,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) -> # forward setup to platform to register the entity await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - + return True @@ -130,12 +130,12 @@ class HassServiceTool(llm.Tool): domain, service = tuple(tool_input.tool_args["service"].split(".")) except ValueError: return { "result": "unknown service" } - + target_device = tool_input.tool_args["target_device"] if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES: return { "result": "unknown service" } - + if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]: return { "result": "unknown service" } @@ -153,12 +153,12 @@ class HassServiceTool(llm.Tool): except Exception: _LOGGER.exception("Failed to execute service for model") return { "result": "failed" } - + return { "result": "success" } class HomeLLMAPI(llm.API): """ - An API that allows calling Home Assistant services to maintain compatibility + An API that allows calling Home Assistant services to maintain compatibility with the older (v3 and older) Home LLM models """ diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index cbe5200..7627d4a 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -217,8 +217,8 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss extra1, extra2 = ({}, {}) default_port = DEFAULT_PORT - if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI: - extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password")) + if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI: + extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password")) elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER: default_port = "8000" elif backend_type == BACKEND_TYPE_OLLAMA: @@ -259,7 +259,7 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss class BaseLlamaConversationConfigFlow(FlowHandler, ABC): - """Represent the base config flow for Z-Wave JS.""" + """Represent the base config flow for Local LLM.""" @property @abstractmethod @@ -335,7 +335,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom """Handle the initial step.""" self.model_config = {} self.options = {} - + # make sure the API is registered if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(self.hass)]): llm.async_register_api(self.hass, HomeLLMAPI(self.hass)) @@ -384,7 +384,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom step_id="install_local_wheels", progress_action="install_local_wheels", ) - + if self.install_wheel_task and not self.install_wheel_task.done(): return self.async_show_progress( progress_task=self.install_wheel_task, @@ -491,7 +491,7 @@ 
class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom step_id="download", progress_action="download", ) - + if self.download_task and not self.download_task.done(): return self.async_show_progress( progress_task=self.download_task, @@ -510,7 +510,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom self.download_task = None return self.async_show_progress_done(next_step_id=next_step) - + async def _async_validate_generic_openai(self, user_input: dict) -> tuple: """ Validates a connection to an OpenAI compatible API server and that the model exists on the remote server @@ -587,7 +587,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom except Exception as ex: _LOGGER.info("Connection error was: %s", repr(ex)) return "failed_to_connect", ex, [] - + async def _async_validate_ollama(self, user_input: dict) -> tuple: """ Validates a connection to ollama and that the model exists on the remote server @@ -619,7 +619,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom model_name = self.model_config[CONF_CHAT_MODEL] if model["name"] == model_name: return (None, None, []) - + return "missing_model_api", None, [x["name"] for x in models_result["models"]] except Exception as ex: @@ -676,7 +676,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom return self.async_show_form( step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False ) - + async def async_step_model_parameters( self, user_input: dict[str, Any] | None = None ) -> FlowResult: @@ -706,7 +706,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<tools>", tools) selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<area>", area) selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<user_instruction>", user_instruction) - + schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type)) if user_input: @@ -718,7 +718,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): errors["base"] = "missing_gbnf_file" description_placeholders["filename"] = filename - + if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES): filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE) if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): @@ -727,7 +727,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom if user_input[CONF_LLM_HASS_API] == "none": user_input.pop(CONF_LLM_HASS_API) - + if len(errors) == 0: try: # validate input @@ -794,7 +794,7 @@ class OptionsFlow(config_entries.OptionsFlow): if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): errors["base"] = "missing_gbnf_file" description_placeholders["filename"] = filename - + if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES): filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE) if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)): @@ -806,7 +806,7 @@ class OptionsFlow(config_entries.OptionsFlow): if len(errors) == 0: return self.async_create_entry(title="Local LLM Conversation", data=user_input) - + schema =
local_llama_config_option_schema( self.hass, self.config_entry.options, diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index fa713c7..9ae00e5 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -119,8 +119,8 @@ CONF_SELECTED_LANGUAGE_OPTIONS = [ "en", "de", "fr", "es", "pl"] CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization" CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = [ "Q4_0", "Q4_1", "Q5_0", "Q5_1", "IQ2_XXS", "IQ2_XS", "IQ2_S", "IQ2_M", "IQ1_S", "IQ1_M", - "Q2_K", "Q2_K_S", "IQ3_XXS", "IQ3_S", "IQ3_M", "Q3_K", "IQ3_XS", "Q3_K_S", "Q3_K_M", "Q3_K_L", - "IQ4_NL", "IQ4_XS", "Q4_K", "Q4_K_S", "Q4_K_M", "Q5_K", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0", + "Q2_K", "Q2_K_S", "IQ3_XXS", "IQ3_S", "IQ3_M", "Q3_K", "IQ3_XS", "Q3_K_S", "Q3_K_M", "Q3_K_L", + "IQ4_NL", "IQ4_XS", "Q4_K", "Q4_K_S", "Q4_K_M", "Q5_K", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0", "F16", "BF16", "F32" ] DEFAULT_DOWNLOADED_MODEL_QUANTIZATION = "Q4_K_M" diff --git a/custom_components/llama_conversation/conversation.py b/custom_components/llama_conversation/conversation.py index 31a9663..1f25b25 100644 --- a/custom_components/llama_conversation/conversation.py +++ b/custom_components/llama_conversation/conversation.py @@ -141,7 +141,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) async def update_listener(hass: HomeAssistant, entry: ConfigEntry): """Handle options update.""" hass.data[DOMAIN][entry.entry_id] = entry - + # call update handler agent: LocalLLMAgent = entry.runtime_data await hass.async_add_executor_job(agent._update_options) @@ -244,7 +244,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): if set(self.in_context_examples[0].keys()) != set(["type", "request", "tool", "response" ]): raise Exception("ICL csv file did not have 2 columns: service & response") - + if len(self.in_context_examples) == 0: _LOGGER.warning(f"There were no in context learning examples found in the file '{filename}'!") self.in_context_examples = None @@ -259,7 +259,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL ) - + if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES): self._load_icl_examples(self.entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)) else: @@ -276,7 +276,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): def supported_languages(self) -> list[str] | Literal["*"]: """Return a list of supported languages.""" return MATCH_ALL - + def _load_model(self, entry: ConfigEntry) -> None: """Load the model on the backend. Implemented by sub-classes""" raise NotImplementedError() @@ -286,7 +286,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): return await self.hass.async_add_executor_job( self._load_model, entry ) - + def _generate(self, conversation: dict) -> str: """Call the backend to generate a response from the conversation. 
Implemented by sub-classes""" raise NotImplementedError() @@ -296,7 +296,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): return await self.hass.async_add_executor_job( self._generate, conversation ) - + def _warn_context_size(self): num_entities = len(self._async_get_exposed_entities()[0]) context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) @@ -334,17 +334,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): service_call_pattern = re.compile(service_call_regex, flags=re.MULTILINE) except Exception as err: _LOGGER.exception("There was a problem compiling the service call regex") - + intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Sorry, there was a problem compiling the service call regex: {err}", ) - + return ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - + llm_api: llm.APIInstance | None = None if self.entry.options.get(CONF_LLM_HASS_API): try: @@ -372,7 +372,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): ) message_history = [ _convert_content(content) for content in chat_log.content ] - + # re-generate prompt if necessary if len(message_history) == 0 or refresh_system_prompt: try: @@ -387,9 +387,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): return ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - + system_prompt = { "role": "system", "message": message } - + if len(message_history) == 0: message_history.append(system_prompt) else: @@ -403,7 +403,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): except Exception as err: _LOGGER.exception("There was a problem talking to the backend") - + intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.FAILED_TO_HANDLE, @@ -412,13 +412,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): return ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - + # remove end of text token if it was returned response = response.replace(template_desc["assistant"]["suffix"], "") - # remove think blocks + # remove think blocks response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL) - + message_history.append({"role": "assistant", "message": response}) if remember_conversation: if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1: @@ -460,7 +460,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): vol.Required("name"): str, vol.Required("arguments"): dict, }) - + try: schema_to_validate(parsed_tool_call) except vol.Error as ex: @@ -487,7 +487,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): # convert string "tuple" to a list for RGB colors if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str): args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ] - + if llm_api.api.id == HOME_LLM_API_ID: to_say = to_say + parsed_tool_call.pop("to_say", "") tool_input = llm.ToolInput( @@ -534,7 +534,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): except Exception as err: _LOGGER.exception("There was a problem talking to the backend") - + intent_response = intent.IntentResponse(language=user_input.language) 
intent_response.async_set_error( intent.IntentResponseErrorCode.FAILED_TO_HANDLE, @@ -546,7 +546,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): message_history.append({"role": "assistant", "message": response}) message_history.append({"role": "assistant", "message": to_say}) - + # generate intent response to Home Assistant intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(to_say.strip()) @@ -577,7 +577,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): if entity: if entity.aliases: attributes["aliases"] = entity.aliases - + if entity.unit_of_measurement: attributes["state"] = attributes["state"] + " " + entity.unit_of_measurement @@ -587,13 +587,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): area_id = device.area_id if entity and entity.area_id: area_id = entity.area_id - + if area_id: area = area_registry.async_get_area(entity.area_id) if area: attributes["area_id"] = area.id attributes["area_name"] = area.name - + entity_states[state.entity_id] = attributes domains.add(state.domain) @@ -627,7 +627,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): _LOGGER.debug(formatted_prompt) return formatted_prompt - + def _format_tool(self, name: str, parameters: vol.Schema, description: str): style = self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) @@ -636,10 +636,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): if description: result = result + f" - {description}" return result - + raw_parameters: list = voluptuous_serialize.convert( parameters, custom_serializer=custom_custom_serializer) - + # handle vol.Any in the key side of things processed_parameters = [] for param in raw_parameters: @@ -685,9 +685,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): } } } - + raise Exception(f"Unknown tool format {style}") - + def _generate_icl_examples(self, num_examples, entity_names): entity_names = entity_names[:] entity_domains = set([x.split(".")[0] for x in entity_names]) @@ -699,14 +699,14 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): x for x in self.in_context_examples if x["type"] in entity_domains ] - + random.shuffle(in_context_examples) random.shuffle(entity_names) num_examples_to_generate = min(num_examples, len(in_context_examples)) if num_examples_to_generate < num_examples: _LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!") - + examples = [] for _ in range(num_examples_to_generate): chosen_example = in_context_examples.pop() @@ -748,7 +748,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): "arguments": tool_arguments } }) - + return examples def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance | None) -> str: @@ -796,7 +796,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): state = attributes["state"] exposed_attributes = expose_attributes(attributes) str_attributes = ";".join([state] + exposed_attributes) - + formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n" devices.append({ "entity_id": name, @@ -828,7 +828,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): for domain in domains: if domain not in SERVICE_TOOL_ALLOWED_DOMAINS: continue - + # scripts show up as individual services if domain == "script" and not scripts_added: 
all_services.extend([ @@ -839,7 +839,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): ]) scripts_added = True continue - + for name, service in service_dict.get(domain, {}).items(): if name not in SERVICE_TOOL_ALLOWED_SERVICES: continue @@ -856,13 +856,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): self._format_tool(*tool) for tool in all_services ] - + else: tools = [ self._format_tool(tool.name, tool.parameters, tool.description) for tool in llm_api.tools ] - + if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL: formatted_tools = ", ".join(tools) else: @@ -883,7 +883,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent): if self.in_context_examples and llm_api: num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES)) render_variables["response_examples"] = self._generate_icl_examples(num_examples, list(entities_to_expose.keys())) - + return template.Template(prompt_template, self.hass).async_render( render_variables, parse_result=False, @@ -910,7 +910,7 @@ class LlamaCppAgent(LocalLLMAgent): if not self.model_path: raise Exception(f"Model was not found at '{self.model_path}'!") - + validate_llama_cpp_python_installation() # don't import it until now because the wheel is installed by config_flow.py @@ -921,10 +921,10 @@ class LlamaCppAgent(LocalLLMAgent): install_result = install_llama_cpp_python(self.hass.config.config_dir) if not install_result == True: raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!") - + validate_llama_cpp_python_installation() self.llama_cpp_module = importlib.import_module("llama_cpp") - + Llama = getattr(self.llama_cpp_module, "Llama") _LOGGER.debug(f"Loading model '{self.model_path}'...") @@ -948,14 +948,14 @@ class LlamaCppAgent(LocalLLMAgent): self.grammar = None if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR): self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)) - + # TODO: check about disk caching # self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache( # capacity_bytes=(512 * 10e8), # cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache") # )) - + self.remove_prompt_caching_listener = None self.last_cache_prime = None self.last_updated_entities = {} @@ -1025,7 +1025,7 @@ class LlamaCppAgent(LocalLLMAgent): if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \ model_reloaded: self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) - + async def cache_current_prompt(_now): await self._async_cache_prompt(None, None, None) async_call_later(self.hass, 1.0, cache_current_prompt) @@ -1039,7 +1039,7 @@ class LlamaCppAgent(LocalLLMAgent): # ignore sorting if prompt caching is disabled if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED): return entities, domains - + entity_order = { name: None for name in entities.keys() } entity_order.update(self.last_updated_entities) @@ -1050,7 +1050,7 @@ class LlamaCppAgent(LocalLLMAgent): return (False, '', item_name) else: return (True, last_updated, '') - + # Sort the items based on the sort_key function sorted_items = sorted(list(entity_order.items()), key=sort_key) @@ -1066,9 +1066,9 @@ 
class LlamaCppAgent(LocalLLMAgent): if enabled and not self.remove_prompt_caching_listener: _LOGGER.info("enabling prompt caching...") - entity_ids = [ + entity_ids = [ state.entity_id for state in self.hass.states.async_all() \ - if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id) + if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id) ] _LOGGER.debug(f"watching entities: {entity_ids}") @@ -1107,27 +1107,27 @@ class LlamaCppAgent(LocalLLMAgent): # if a refresh is already scheduled then exit if self.cache_refresh_after_cooldown: return - + # if we are inside the cooldown period, request a refresh and exit current_time = time.time() fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval: self.cache_refresh_after_cooldown = True return - + # try to acquire the lock, if we are still running for some reason, request a refresh and exit lock_acquired = self.model_lock.acquire(False) if not lock_acquired: self.cache_refresh_after_cooldown = True return - + try: raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) prompt = self._format_prompt([ { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)}, { "role": "user", "message": "" } ], include_generation_prompt=False) - + input_tokens = self.llm.tokenize( prompt.encode(), add_bos=False ) @@ -1154,10 +1154,10 @@ class LlamaCppAgent(LocalLLMAgent): finally: self.model_lock.release() - + # schedule a refresh using async_call_later # if the flag is set after the delay then we do another refresh - + @callback async def refresh_if_requested(_now): if self.cache_refresh_after_cooldown: @@ -1172,8 +1172,8 @@ class LlamaCppAgent(LocalLLMAgent): refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) async_call_later(self.hass, float(refresh_delay), refresh_if_requested) - - + + def _generate(self, conversation: dict) -> str: prompt = self._format_prompt(conversation) @@ -1224,7 +1224,7 @@ class LlamaCppAgent(LocalLLMAgent): result = self.llm.detokenize(result_tokens).decode() return result - + class GenericOpenAIAPIAgent(LocalLLMAgent): api_host: str api_key: str @@ -1237,7 +1237,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent): ssl=entry.data[CONF_SSL], path="" ) - + self.api_key = entry.data.get(CONF_OPENAI_API_KEY) self.model_name = entry.data.get(CONF_CHAT_MODEL) @@ -1259,7 +1259,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent): request_params["prompt"] = self._format_prompt(conversation) return endpoint, request_params - + def _extract_response(self, response_json: dict) -> str: choices = response_json["choices"] if choices[0]["finish_reason"] != "stop": @@ -1269,14 +1269,14 @@ class GenericOpenAIAPIAgent(LocalLLMAgent): return choices[0]["message"]["content"] else: return choices[0]["text"] - + async def _async_generate(self, conversation: dict) -> str: max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT) use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT) - + request_params = { "model": self.model_name, @@ -1284,12 +1284,12 @@ class GenericOpenAIAPIAgent(LocalLLMAgent): "temperature": temperature, "top_p": 
top_p, } - + if use_chat_api: endpoint, additional_params = self._chat_completion_params(conversation) else: endpoint, additional_params = self._completion_params(conversation) - + request_params.update(additional_params) headers = {} @@ -1318,7 +1318,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent): _LOGGER.debug(result) return self._extract_response(result) - + class TextGenerationWebuiAgent(GenericOpenAIAPIAgent): admin_key: str @@ -1332,21 +1332,21 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent): if self.admin_key: headers["Authorization"] = f"Bearer {self.admin_key}" - + async with session.get( f"{self.api_host}/v1/internal/model/info", headers=headers ) as response: response.raise_for_status() currently_loaded_result = await response.json() - + loaded_model = currently_loaded_result["model_name"] if loaded_model == self.model_name: _LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.") return else: _LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...") - + async with session.post( f"{self.api_host}/v1/internal/model/load", json={ @@ -1381,7 +1381,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent): request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P) return endpoint, request_params - + def _completion_params(self, conversation: dict) -> (str, dict): preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET) @@ -1396,7 +1396,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent): request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P) return endpoint, request_params - + def _extract_response(self, response_json: dict) -> str: choices = response_json["choices"] if choices[0]["finish_reason"] != "stop": @@ -1412,7 +1412,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent): return choices[0]["message"]["content"] else: return choices[0]["text"] - + class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent): """https://llama-cpp-python.readthedocs.io/en/latest/server/""" grammar: str @@ -1438,13 +1438,13 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent): request_params["grammar"] = self.grammar return endpoint, request_params - + def _completion_params(self, conversation: dict) -> (str, dict): top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)) endpoint, request_params = super()._completion_params(conversation) request_params["top_k"] = top_k - + if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR): request_params["grammar"] = self.grammar @@ -1471,14 +1471,14 @@ class OllamaAPIAgent(LocalLLMAgent): session = async_get_clientsession(self.hass) if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" - + async with session.get( f"{self.api_host}/api/tags", headers=headers, ) as response: response.raise_for_status() currently_downloaded_result = await response.json() - + except Exception as ex: _LOGGER.debug("Connection error was: %s", repr(ex)) raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex @@ -1486,7 +1486,7 @@ class OllamaAPIAgent(LocalLLMAgent): model_names = [ x["name"] for x in currently_downloaded_result["models"] ] if ":" in self.model_name: if not any([ name == self.model_name for name in model_names]): - raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}") + raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}") elif not any([ name.split(":")[0] == 
self.model_name for name in model_names ]): raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}") @@ -1506,11 +1506,11 @@ class OllamaAPIAgent(LocalLLMAgent): request_params["raw"] = True # ignore prompt template return endpoint, request_params - - def _extract_response(self, response_json: dict) -> str: + + def _extract_response(self, response_json: dict) -> str: if response_json["done"] not in ["true", True]: _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)") - + # TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) # max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) @@ -1521,7 +1521,7 @@ class OllamaAPIAgent(LocalLLMAgent): return response_json["response"] else: return response_json["message"]["content"] - + async def _async_generate(self, conversation: dict) -> str: context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) @@ -1533,7 +1533,7 @@ class OllamaAPIAgent(LocalLLMAgent): keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN) use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT) json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE) - + request_params = { "model": self.model_name, "stream": False, @@ -1550,18 +1550,18 @@ class OllamaAPIAgent(LocalLLMAgent): if json_mode: request_params["format"] = "json" - + if use_chat_api: endpoint, additional_params = self._chat_completion_params(conversation) else: endpoint, additional_params = self._completion_params(conversation) - + request_params.update(additional_params) headers = {} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" - + session = async_get_clientsession(self.hass) response = None try: @@ -1580,7 +1580,7 @@ class OllamaAPIAgent(LocalLLMAgent): _LOGGER.debug(f"Request was: {request_params}") _LOGGER.debug(f"Result was: {response}") return f"Failed to communicate with the API! {err}" - + _LOGGER.debug(result) return self._extract_response(result)
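
Most of the hunks above are trailing-whitespace cleanup, but the code they pass through contains several patterns worth sketching for reviewers. First, the service gating in HassServiceTool (__init__.py): the model's "domain.service" string is split, checked against allow-lists, and scripts are further restricted because each script is exposed as its own service. A minimal standalone sketch; the allow-list contents here are illustrative, the real ALLOWED_DOMAINS/ALLOWED_SERVICES are attributes on the tool class.

# Standalone sketch of HassServiceTool's service gating; allow-list
# contents are illustrative, the logic mirrors the diff.
ALLOWED_DOMAINS = {"light", "switch", "script", "climate"}
ALLOWED_SERVICES = {"turn_on", "turn_off", "toggle", "reload", "set_temperature"}
SCRIPT_SERVICES = {"reload", "turn_on", "turn_off", "toggle"}

def validate_service(service: str) -> dict:
    try:
        domain, name = service.split(".")
    except ValueError:
        return {"result": "unknown service"}
    if domain not in ALLOWED_DOMAINS or name not in ALLOWED_SERVICES:
        return {"result": "unknown service"}
    # each script is exposed as its own service, so only the generic
    # script services are callable through this path
    if domain == "script" and name not in SCRIPT_SERVICES:
        return {"result": "unknown service"}
    return {"result": "success"}

assert validate_service("light.turn_on") == {"result": "success"}
assert validate_service("script.set_temperature") == {"result": "unknown service"}
assert validate_service("nonsense") == {"result": "unknown service"}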
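
The config flow's _async_validate_ollama and OllamaAPIAgent._load_model share the same model-lookup rule against the names returned by the server's /api/tags endpoint: a name containing an explicit ":tag" must match exactly, while a bare name matches any tag of that model. A standalone sketch of the rule:

# Standalone sketch of the Ollama model lookup shared by the config
# flow validator and OllamaAPIAgent.
def model_available(wanted: str, downloaded: list[str]) -> bool:
    if ":" in wanted:
        return any(name == wanted for name in downloaded)
    return any(name.split(":")[0] == wanted for name in downloaded)

tags = ["llama3:8b-instruct-q4_K_M", "home-llm:latest"]
assert model_available("home-llm", tags)
assert model_available("llama3:8b-instruct-q4_K_M", tags)
assert not model_available("llama3:70b", tags)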
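
In conversation.py, _load_icl_examples now expects four CSV columns (type, request, tool, response); note that the exception text visible in the context lines still reads "2 columns: service & response", a leftover from the older file format. A standalone sketch of the validation, assuming csv.DictReader as the loader (the loader itself is outside this diff):

import csv
import io

# Standalone sketch of the ICL file validation; csv.DictReader is an
# assumption, only the expected column set comes from the diff.
EXPECTED_COLUMNS = {"type", "request", "tool", "response"}

def load_icl_examples(csv_text: str) -> list[dict] | None:
    examples = list(csv.DictReader(io.StringIO(csv_text)))
    if not examples:
        return None  # the integration logs a warning and disables ICL
    if set(examples[0].keys()) != EXPECTED_COLUMNS:
        raise Exception("ICL csv file did not have the expected columns: "
                        "type, request, tool & response")
    return examples

sample = "type,request,tool,response\nlight,Turn on the lamp,HassTurnOn,Turning on the lamp.\n"
print(load_icl_examples(sample))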
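
After generation, async_process strips the template's end-of-text suffix and any chain-of-thought block. A standalone sketch; the <|im_end|> and </think> markers are illustrative assumptions, since the integration reads the real values from the configured prompt template description:

import re

# Standalone sketch of the response cleanup in async_process; marker
# strings are illustrative example values.
template_desc = {
    "assistant": {"suffix": "<|im_end|>"},
    "chain_of_thought": {"suffix": "</think>"},
}

def clean_response(response: str) -> str:
    # drop the end-of-text token if the backend returned it
    response = response.replace(template_desc["assistant"]["suffix"], "")
    # drop everything up to and including the closing think tag
    cot_suffix = re.escape(template_desc["chain_of_thought"]["suffix"])
    return re.sub(rf"^.*?{cot_suffix}", "", response, flags=re.DOTALL)

raw = "<think>\nthe user wants the lamp on\n</think>\nTurning on the lamp.<|im_end|>"
print(clean_response(raw).strip())  # Turning on the lamp.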
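
Parsed tool calls are validated with a small voluptuous schema before dispatch, and string RGB "tuples" are coerced into integer lists. A condensed standalone sketch (requires the voluptuous package):

import voluptuous as vol

# Standalone sketch of the tool-call validation in async_process,
# condensed from the diff's schema and rgb_color coercion.
TOOL_CALL_SCHEMA = vol.Schema({
    vol.Required("name"): str,
    vol.Required("arguments"): dict,
})

def validate_tool_call(parsed: dict) -> dict:
    tool_call = TOOL_CALL_SCHEMA(parsed)  # raises vol.Invalid on a bad shape
    args = tool_call["arguments"]
    # models sometimes emit rgb_color as the string "(255, 0, 0)"
    if isinstance(args.get("rgb_color"), str):
        args["rgb_color"] = [int(x) for x in args["rgb_color"][1:-1].split(",")]
    return tool_call

call = {"name": "HassLightSet", "arguments": {"rgb_color": "(255, 128, 0)"}}
print(validate_tool_call(call))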
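
The llama.cpp backend reorders exposed entities when prompt caching is enabled so the prompt keeps a stable prefix: entities with no recorded update sort first (alphabetically), the rest follow in ascending last-updated order, pushing the most recently changed entities to the end where a cache miss costs the least. A standalone sketch using the sort key from the diff:

# Standalone sketch of the entity ordering used for prompt caching;
# the sort key is taken directly from the diff.
def sort_entities(entities: dict[str, str | None]) -> list[str]:
    def sort_key(item):
        name, last_updated = item
        if not last_updated:
            return (False, "", name)
        return (True, last_updated, "")
    return [name for name, _ in sorted(entities.items(), key=sort_key)]

entities = {
    "light.kitchen": "2024-05-01T10:00:00",
    "sensor.outside_temp": "2024-05-01T12:30:00",
    "switch.heater": None,   # never updated since startup
    "light.bedroom": None,
}
print(sort_entities(entities))
# ['light.bedroom', 'switch.heater', 'light.kitchen', 'sensor.outside_temp']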
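
The _async_cache_prompt hunks implement a debounce around cache priming: duplicate requests are dropped, a request arriving inside the cooldown window or while the model lock is held only sets a flag, and a timer armed after each successful prime replays any deferred request. A standalone sketch of that pattern, with threading.Timer standing in for Home Assistant's async_call_later and the timing semantics approximated:

import threading
import time

# Standalone sketch of the cache-priming debounce in LlamaCppAgent;
# prime() stands in for the actual prompt evaluation.
class CacheRefresher:
    def __init__(self, cooldown: float, prime):
        self.cooldown = cooldown          # CONF_PROMPT_CACHING_INTERVAL
        self.prime = prime
        self.model_lock = threading.Lock()
        self.last_prime = 0.0
        self.refresh_requested = False

    def request_refresh(self):
        if self.refresh_requested:
            return  # a deferred refresh is already pending
        if time.time() - self.last_prime < self.cooldown:
            self.refresh_requested = True   # inside cooldown: defer
            return
        if not self.model_lock.acquire(blocking=False):
            self.refresh_requested = True   # model is busy: defer
            return
        try:
            self.prime()
            self.last_prime = time.time()
        finally:
            self.model_lock.release()
        # after the cooldown, replay any request that arrived meanwhile
        threading.Timer(self.cooldown, self._refresh_if_requested).start()

    def _refresh_if_requested(self):
        if self.refresh_requested:
            self.refresh_requested = False
            self.request_refresh()

refresher = CacheRefresher(1.0, lambda: print("primed prompt cache"))
refresher.request_refresh()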
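
GenericOpenAIAPIAgent builds one shared parameter set and then branches: the chat path posts a messages list to /v1/chat/completions, while the completion path posts a pre-formatted prompt string to /v1/completions. A standalone sketch of the branching; the sampling values are illustrative, not the integration's defaults:

# Standalone sketch of the endpoint branching in GenericOpenAIAPIAgent.
def build_request(conversation: list[dict], model: str, use_chat_api: bool,
                  format_prompt=None) -> tuple[str, dict]:
    params = {
        "model": model,
        "max_tokens": 128,
        "temperature": 0.1,
        "top_p": 1.0,
    }
    if use_chat_api:
        endpoint = "/v1/chat/completions"
        # the agent stores turns as {"role", "message"} internally
        params["messages"] = [
            {"role": m["role"], "content": m["message"]} for m in conversation
        ]
    else:
        endpoint = "/v1/completions"
        params["prompt"] = format_prompt(conversation)
    return endpoint, params

conversation = [
    {"role": "system", "message": "You control a smart home."},
    {"role": "user", "message": "Turn on the kitchen light."},
]
endpoint, params = build_request(conversation, "home-llm", use_chat_api=True)
print(endpoint, sorted(params))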
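
TextGenerationWebuiAgent adds a preload step on top of the OpenAI-compatible API, using text-generation-webui's internal endpoints to check which model is loaded and to load the configured one if it differs. (Incidentally, the context log line "Model is not {self.model_name} loaded on the remote backend" reads awkwardly; "Model {self.model_name} is not loaded" is what is meant.) A standalone sketch, with the load request body trimmed to the model name (requires aiohttp):

import aiohttp

# Standalone sketch of TextGenerationWebuiAgent's model preload, using
# the internal endpoints shown in the diff; the load body is trimmed.
async def ensure_model_loaded(session: aiohttp.ClientSession, api_host: str,
                              model_name: str, admin_key: str | None = None):
    headers = {"Authorization": f"Bearer {admin_key}"} if admin_key else {}
    async with session.get(f"{api_host}/v1/internal/model/info",
                           headers=headers) as resp:
        resp.raise_for_status()
        info = await resp.json()
    if info["model_name"] == model_name:
        return  # already loaded on the remote backend
    async with session.post(f"{api_host}/v1/internal/model/load",
                            json={"model_name": model_name},
                            headers=headers) as resp:
        resp.raise_for_status()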
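
Finally, OllamaAPIAgent maps the integration's options onto Ollama's request body: sampling knobs live under "options", keep_alive is expressed in minutes, format: "json" enables JSON mode, and the completion path sets raw so Ollama does not re-apply its own prompt template. A standalone sketch with illustrative option values:

# Standalone sketch of the request body assembled in
# OllamaAPIAgent._async_generate; option values are illustrative.
def build_ollama_request(model: str, prompt: str, json_mode: bool = False,
                         keep_alive_min: int = 30) -> tuple[str, dict]:
    params = {
        "model": model,
        "stream": False,
        "keep_alive": f"{keep_alive_min}m",  # minutes to keep the model loaded
        "options": {
            "num_ctx": 2048,      # context length
            "num_predict": 128,   # max tokens to generate
            "top_k": 40,
            "top_p": 1.0,
            "temperature": 0.1,
        },
        # completion endpoint: send the fully formatted prompt and set
        # raw so Ollama skips its own prompt template
        "prompt": prompt,
        "raw": True,
    }
    if json_mode:
        params["format"] = "json"
    return "/api/generate", params

endpoint, body = build_ollama_request("home-llm:latest", "<|im_start|>user ...")
print(endpoint, body["options"]["num_ctx"])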