Clean up whitespace

Simon Redman
2025-05-31 14:18:59 -04:00
parent 30d9f48006
commit 58efdb9601
4 changed files with 106 additions and 106 deletions

View File

@@ -62,7 +62,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
agent_cls = LlamaCppPythonAPIAgent
elif backend_type == BACKEND_TYPE_OLLAMA:
agent_cls = OllamaAPIAgent
return agent_cls(hass, entry)
# create the agent in an executor job because the constructor calls `open()`
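For context, the comment above refers to the standard pattern of keeping a blocking constructor off the event loop by running it in an executor. A minimal standalone asyncio sketch of that pattern (BlockingAgent and the file path are made-up stand-ins, not this integration's classes):

import asyncio

class BlockingAgent:
    """Stand-in for an agent whose constructor does blocking I/O such as open()."""
    def __init__(self, path: str):
        with open(path) as f:   # blocking call that must not run on the event loop
            self.config = f.read()

async def make_agent(path: str) -> BlockingAgent:
    loop = asyncio.get_running_loop()
    # run the blocking constructor in the default thread pool executor
    return await loop.run_in_executor(None, BlockingAgent, path)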
@@ -74,7 +74,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
# forward setup to platform to register the entity
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
@@ -130,12 +130,12 @@ class HassServiceTool(llm.Tool):
domain, service = tuple(tool_input.tool_args["service"].split("."))
except ValueError:
return { "result": "unknown service" }
target_device = tool_input.tool_args["target_device"]
if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES:
return { "result": "unknown service" }
if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]:
return { "result": "unknown service" }
@@ -153,12 +153,12 @@ class HassServiceTool(llm.Tool):
except Exception:
_LOGGER.exception("Failed to execute service for model")
return { "result": "failed" }
return { "result": "success" }
class HomeLLMAPI(llm.API):
"""
An API that allows calling Home Assistant services to maintain compatibility
with the older (v3 and older) Home LLM models
"""

View File

@@ -217,8 +217,8 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss
extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
default_port = "8000"
elif backend_type == BACKEND_TYPE_OLLAMA:
@@ -259,7 +259,7 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss
class BaseLlamaConversationConfigFlow(FlowHandler, ABC):
"""Represent the base config flow for Z-Wave JS."""
"""Represent the base config flow for Local LLM."""
@property
@abstractmethod
@@ -335,7 +335,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
"""Handle the initial step."""
self.model_config = {}
self.options = {}
# make sure the API is registered
if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(self.hass)]):
llm.async_register_api(self.hass, HomeLLMAPI(self.hass))
@@ -384,7 +384,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
step_id="install_local_wheels",
progress_action="install_local_wheels",
)
if self.install_wheel_task and not self.install_wheel_task.done():
return self.async_show_progress(
progress_task=self.install_wheel_task,
@@ -491,7 +491,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
step_id="download",
progress_action="download",
)
if self.download_task and not self.download_task.done():
return self.async_show_progress(
progress_task=self.download_task,
@@ -510,7 +510,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self.download_task = None
return self.async_show_progress_done(next_step_id=next_step)
async def _async_validate_generic_openai(self, user_input: dict) -> tuple:
"""
Validates a connection to an OpenAI compatible API server and that the model exists on the remote server
@@ -587,7 +587,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
except Exception as ex:
_LOGGER.info("Connection error was: %s", repr(ex))
return "failed_to_connect", ex, []
async def _async_validate_ollama(self, user_input: dict) -> tuple:
"""
Validates a connection to ollama and that the model exists on the remote server
@@ -619,7 +619,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
model_name = self.model_config[CONF_CHAT_MODEL]
if model["name"] == model_name:
return (None, None, [])
return "missing_model_api", None, [x["name"] for x in models_result["models"]]
except Exception as ex:
@@ -676,7 +676,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
return self.async_show_form(
step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
)
async def async_step_model_parameters(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
@@ -706,7 +706,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<tools>", tools)
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<area>", area)
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<user_instruction>", user_instruction)
schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))
if user_input:
@@ -718,7 +718,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
@@ -727,7 +727,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
if len(errors) == 0:
try:
# validate input
@@ -794,7 +794,7 @@ class OptionsFlow(config_entries.OptionsFlow):
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename
if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
@@ -806,7 +806,7 @@ class OptionsFlow(config_entries.OptionsFlow):
if len(errors) == 0:
return self.async_create_entry(title="Local LLM Conversation", data=user_input)
schema = local_llama_config_option_schema(
self.hass,
self.config_entry.options,

View File

@@ -119,8 +119,8 @@ CONF_SELECTED_LANGUAGE_OPTIONS = [ "en", "de", "fr", "es", "pl"]
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = [
"Q4_0", "Q4_1", "Q5_0", "Q5_1", "IQ2_XXS", "IQ2_XS", "IQ2_S", "IQ2_M", "IQ1_S", "IQ1_M",
"Q2_K", "Q2_K_S", "IQ3_XXS", "IQ3_S", "IQ3_M", "Q3_K", "IQ3_XS", "Q3_K_S", "Q3_K_M", "Q3_K_L",
"IQ4_NL", "IQ4_XS", "Q4_K", "Q4_K_S", "Q4_K_M", "Q5_K", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0",
"Q2_K", "Q2_K_S", "IQ3_XXS", "IQ3_S", "IQ3_M", "Q3_K", "IQ3_XS", "Q3_K_S", "Q3_K_M", "Q3_K_L",
"IQ4_NL", "IQ4_XS", "Q4_K", "Q4_K_S", "Q4_K_M", "Q5_K", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0",
"F16", "BF16", "F32"
]
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION = "Q4_K_M"

View File

@@ -141,7 +141,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
"""Handle options update."""
hass.data[DOMAIN][entry.entry_id] = entry
# call update handler
agent: LocalLLMAgent = entry.runtime_data
await hass.async_add_executor_job(agent._update_options)
@@ -244,7 +244,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if set(self.in_context_examples[0].keys()) != set(["type", "request", "tool", "response" ]):
raise Exception("ICL csv file did not have 2 columns: service & response")
if len(self.in_context_examples) == 0:
_LOGGER.warning(f"There were no in context learning examples found in the file '{filename}'!")
self.in_context_examples = None
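Loading and sanity-checking an ICL example file of the shape implied above can be sketched with the standard csv module (the path is hypothetical; the expected column set is taken from the check in this hunk):

import csv

def load_icl_examples(path: str):
    with open(path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    if rows and set(rows[0].keys()) != {"type", "request", "tool", "response"}:
        raise ValueError("ICL csv file must have the columns: type, request, tool, response")
    return rows or None   # mirror the 'no examples found' fallback above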
@@ -259,7 +259,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
self._load_icl_examples(self.entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
else:
@@ -276,7 +276,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
def _load_model(self, entry: ConfigEntry) -> None:
"""Load the model on the backend. Implemented by sub-classes"""
raise NotImplementedError()
@@ -286,7 +286,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return await self.hass.async_add_executor_job(
self._load_model, entry
)
def _generate(self, conversation: dict) -> str:
"""Call the backend to generate a response from the conversation. Implemented by sub-classes"""
raise NotImplementedError()
@@ -296,7 +296,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return await self.hass.async_add_executor_job(
self._generate, conversation
)
def _warn_context_size(self):
num_entities = len(self._async_get_exposed_entities()[0])
context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
@@ -334,17 +334,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
service_call_pattern = re.compile(service_call_regex, flags=re.MULTILINE)
except Exception as err:
_LOGGER.exception("There was a problem compiling the service call regex")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, there was a problem compiling the service call regex: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
llm_api: llm.APIInstance | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
@@ -372,7 +372,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
)
message_history = [ _convert_content(content) for content in chat_log.content ]
# re-generate prompt if necessary
if len(message_history) == 0 or refresh_system_prompt:
try:
@@ -387,9 +387,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
system_prompt = { "role": "system", "message": message }
if len(message_history) == 0:
message_history.append(system_prompt)
else:
@@ -403,7 +403,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
@@ -412,13 +412,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# remove end of text token if it was returned
response = response.replace(template_desc["assistant"]["suffix"], "")
# remove think blocks
response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)
message_history.append({"role": "assistant", "message": response})
if remember_conversation:
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
@@ -460,7 +460,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
vol.Required("name"): str,
vol.Required("arguments"): dict,
})
try:
schema_to_validate(parsed_tool_call)
except vol.Error as ex:
@@ -487,7 +487,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
# convert string "tuple" to a list for RGB colors
if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
if llm_api.api.id == HOME_LLM_API_ID:
to_say = to_say + parsed_tool_call.pop("to_say", "")
tool_input = llm.ToolInput(
@@ -534,7 +534,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
@@ -546,7 +546,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
message_history.append({"role": "assistant", "message": response})
message_history.append({"role": "assistant", "message": to_say})
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say.strip())
@@ -577,7 +577,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if entity:
if entity.aliases:
attributes["aliases"] = entity.aliases
if entity.unit_of_measurement:
attributes["state"] = attributes["state"] + " " + entity.unit_of_measurement
@@ -587,13 +587,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
area_id = device.area_id
if entity and entity.area_id:
area_id = entity.area_id
if area_id:
area = area_registry.async_get_area(area_id)
if area:
attributes["area_id"] = area.id
attributes["area_name"] = area.name
entity_states[state.entity_id] = attributes
domains.add(state.domain)
@@ -627,7 +627,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
_LOGGER.debug(formatted_prompt)
return formatted_prompt
def _format_tool(self, name: str, parameters: vol.Schema, description: str):
style = self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT)
@@ -636,10 +636,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if description:
result = result + f" - {description}"
return result
raw_parameters: list = voluptuous_serialize.convert(
parameters, custom_serializer=custom_custom_serializer)
# handle vol.Any in the key side of things
processed_parameters = []
for param in raw_parameters:
@@ -685,9 +685,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
}
}
}
raise Exception(f"Unknown tool format {style}")
def _generate_icl_examples(self, num_examples, entity_names):
entity_names = entity_names[:]
entity_domains = set([x.split(".")[0] for x in entity_names])
@@ -699,14 +699,14 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
x for x in self.in_context_examples
if x["type"] in entity_domains
]
random.shuffle(in_context_examples)
random.shuffle(entity_names)
num_examples_to_generate = min(num_examples, len(in_context_examples))
if num_examples_to_generate < num_examples:
_LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!")
examples = []
for _ in range(num_examples_to_generate):
chosen_example = in_context_examples.pop()
@@ -748,7 +748,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"arguments": tool_arguments
}
})
return examples
def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance | None) -> str:
@@ -796,7 +796,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
state = attributes["state"]
exposed_attributes = expose_attributes(attributes)
str_attributes = ";".join([state] + exposed_attributes)
formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n"
devices.append({
"entity_id": name,
@@ -828,7 +828,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
for domain in domains:
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
continue
# scripts show up as individual services
if domain == "script" and not scripts_added:
all_services.extend([
@@ -839,7 +839,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
])
scripts_added = True
continue
for name, service in service_dict.get(domain, {}).items():
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
continue
@@ -856,13 +856,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self._format_tool(*tool)
for tool in all_services
]
else:
tools = [
self._format_tool(tool.name, tool.parameters, tool.description)
for tool in llm_api.tools
]
if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL:
formatted_tools = ", ".join(tools)
else:
@@ -883,7 +883,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if self.in_context_examples and llm_api:
num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
render_variables["response_examples"] = self._generate_icl_examples(num_examples, list(entities_to_expose.keys()))
return template.Template(prompt_template, self.hass).async_render(
render_variables,
parse_result=False,
@@ -910,7 +910,7 @@ class LlamaCppAgent(LocalLLMAgent):
if not self.model_path:
raise Exception(f"Model was not found at '{self.model_path}'!")
validate_llama_cpp_python_installation()
# don't import it until now because the wheel is installed by config_flow.py
@@ -921,10 +921,10 @@ class LlamaCppAgent(LocalLLMAgent):
install_result = install_llama_cpp_python(self.hass.config.config_dir)
if install_result is not True:
raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")
validate_llama_cpp_python_installation()
self.llama_cpp_module = importlib.import_module("llama_cpp")
Llama = getattr(self.llama_cpp_module, "Llama")
_LOGGER.debug(f"Loading model '{self.model_path}'...")
@@ -948,14 +948,14 @@ class LlamaCppAgent(LocalLLMAgent):
self.grammar = None
if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE))
# TODO: check about disk caching
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
# capacity_bytes=(512 * 10e8),
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
# ))
self.remove_prompt_caching_listener = None
self.last_cache_prime = None
self.last_updated_entities = {}
@@ -1025,7 +1025,7 @@ class LlamaCppAgent(LocalLLMAgent):
if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \
model_reloaded:
self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
async def cache_current_prompt(_now):
await self._async_cache_prompt(None, None, None)
async_call_later(self.hass, 1.0, cache_current_prompt)
@@ -1039,7 +1039,7 @@ class LlamaCppAgent(LocalLLMAgent):
# ignore sorting if prompt caching is disabled
if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
return entities, domains
entity_order = { name: None for name in entities.keys() }
entity_order.update(self.last_updated_entities)
@@ -1050,7 +1050,7 @@ class LlamaCppAgent(LocalLLMAgent):
return (False, '', item_name)
else:
return (True, last_updated, '')
# Sort the items based on the sort_key function
sorted_items = sorted(list(entity_order.items()), key=sort_key)
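The sort key above places entities that have never been updated first (False sorts before True) and orders the rest by ascending last-update time. A small usage example under that reading (entity names and timestamps are made up):

entity_order = {
    "light.kitchen": 1700000100.0,   # updated most recently
    "switch.fan":    1700000050.0,   # updated earlier
    "sensor.new":    None,           # never updated -> should come first
}

def sort_key(item):
    item_name, last_updated = item
    if last_updated is None:
        return (False, '', item_name)
    return (True, last_updated, '')

sorted_items = sorted(entity_order.items(), key=sort_key)
# [('sensor.new', None), ('switch.fan', 1700000050.0), ('light.kitchen', 1700000100.0)]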
@@ -1066,9 +1066,9 @@ class LlamaCppAgent(LocalLLMAgent):
if enabled and not self.remove_prompt_caching_listener:
_LOGGER.info("enabling prompt caching...")
entity_ids = [
state.entity_id for state in self.hass.states.async_all() \
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
]
_LOGGER.debug(f"watching entities: {entity_ids}")
@@ -1107,27 +1107,27 @@ class LlamaCppAgent(LocalLLMAgent):
# if a refresh is already scheduled then exit
if self.cache_refresh_after_cooldown:
return
# if we are inside the cooldown period, request a refresh and exit
current_time = time.time()
fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval:
self.cache_refresh_after_cooldown = True
return
# try to acquire the lock, if we are still running for some reason, request a refresh and exit
lock_acquired = self.model_lock.acquire(False)
if not lock_acquired:
self.cache_refresh_after_cooldown = True
return
try:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt = self._format_prompt([
{ "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
{ "role": "user", "message": "" }
], include_generation_prompt=False)
input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
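The comments in this hunk describe a cooldown/debounce guard: skip the refresh while inside the cooldown window or while the previous run still holds the lock, and remember that a refresh was requested so it can run later. A standalone sketch of that guard outside Home Assistant (class and method names are hypothetical):

import threading
import time

class CooldownGuard:
    def __init__(self, interval: float):
        self.interval = interval
        self.last_run = None
        self.refresh_requested = False
        self.lock = threading.Lock()

    def try_run(self, work) -> bool:
        now = time.time()
        # inside the cooldown window: remember the request and bail out
        if self.last_run is not None and now - self.last_run < self.interval:
            self.refresh_requested = True
            return False
        # previous run still in progress: remember the request and bail out
        if not self.lock.acquire(blocking=False):
            self.refresh_requested = True
            return False
        try:
            self.last_run = now
            work()
            return True
        finally:
            self.lock.release()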
@@ -1154,10 +1154,10 @@ class LlamaCppAgent(LocalLLMAgent):
finally:
self.model_lock.release()
# schedule a refresh using async_call_later
# if the flag is set after the delay then we do another refresh
@callback
async def refresh_if_requested(_now):
if self.cache_refresh_after_cooldown:
@@ -1172,8 +1172,8 @@ class LlamaCppAgent(LocalLLMAgent):
refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
def _generate(self, conversation: dict) -> str:
prompt = self._format_prompt(conversation)
@@ -1224,7 +1224,7 @@ class LlamaCppAgent(LocalLLMAgent):
result = self.llm.detokenize(result_tokens).decode()
return result
class GenericOpenAIAPIAgent(LocalLLMAgent):
api_host: str
api_key: str
@@ -1237,7 +1237,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
ssl=entry.data[CONF_SSL],
path=""
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
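The constructor above assembles the backend's base URL from the host, port, SSL flag and path stored in the config entry; only the keyword arguments are visible in this hunk, so the helper name below is assumed. A plain-Python sketch of that assembly:

def format_url(*, hostname: str, port: str, ssl: bool, path: str = "") -> str:
    scheme = "https" if ssl else "http"
    return f"{scheme}://{hostname}:{port}{path}"

# format_url(hostname="127.0.0.1", port="5000", ssl=False) -> "http://127.0.0.1:5000"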
@@ -1259,7 +1259,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
request_params["prompt"] = self._format_prompt(conversation)
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
@@ -1269,14 +1269,14 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
async def _async_generate(self, conversation: dict) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
request_params = {
"model": self.model_name,
@@ -1284,12 +1284,12 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
"temperature": temperature,
"top_p": top_p,
}
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
request_params.update(additional_params)
headers = {}
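The request dictionary assembled above is ultimately POSTed to an OpenAI-compatible endpoint. A minimal aiohttp sketch of such a call (URL, payload shape and timeout are illustrative, not this file's exact code):

import aiohttp

async def post_completion(base_url: str, request_params: dict, api_key: str | None, timeout: float = 90.0) -> dict:
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{base_url}/v1/chat/completions",
            json=request_params,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=timeout),
        ) as response:
            response.raise_for_status()
            return await response.json()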
@@ -1318,7 +1318,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
_LOGGER.debug(result)
return self._extract_response(result)
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
admin_key: str
@@ -1332,21 +1332,21 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
if self.admin_key:
headers["Authorization"] = f"Bearer {self.admin_key}"
async with session.get(
f"{self.api_host}/v1/internal/model/info",
headers=headers
) as response:
response.raise_for_status()
currently_loaded_result = await response.json()
loaded_model = currently_loaded_result["model_name"]
if loaded_model == self.model_name:
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
return
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
async with session.post(
f"{self.api_host}/v1/internal/model/load",
json={
@@ -1381,7 +1381,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
@@ -1396,7 +1396,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
@@ -1412,7 +1412,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
"""https://llama-cpp-python.readthedocs.io/en/latest/server/"""
grammar: str
@@ -1438,13 +1438,13 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
request_params["grammar"] = self.grammar
return endpoint, request_params
def _completion_params(self, conversation: dict) -> (str, dict):
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._completion_params(conversation)
request_params["top_k"] = top_k
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar
@@ -1471,14 +1471,14 @@ class OllamaAPIAgent(LocalLLMAgent):
session = async_get_clientsession(self.hass)
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
async with session.get(
f"{self.api_host}/api/tags",
headers=headers,
) as response:
response.raise_for_status()
currently_downloaded_result = await response.json()
except Exception as ex:
_LOGGER.debug("Connection error was: %s", repr(ex))
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
@@ -1486,7 +1486,7 @@ class OllamaAPIAgent(LocalLLMAgent):
model_names = [ x["name"] for x in currently_downloaded_result["models"] ]
if ":" in self.model_name:
if not any([ name == self.model_name for name in model_names]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
@@ -1506,11 +1506,11 @@ class OllamaAPIAgent(LocalLLMAgent):
request_params["raw"] = True # ignore prompt template
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
if response_json["done"] not in ["true", True]:
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -1521,7 +1521,7 @@ class OllamaAPIAgent(LocalLLMAgent):
return response_json["response"]
else:
return response_json["message"]["content"]
async def _async_generate(self, conversation: dict) -> str:
context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -1533,7 +1533,7 @@ class OllamaAPIAgent(LocalLLMAgent):
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
request_params = {
"model": self.model_name,
"stream": False,
@@ -1550,18 +1550,18 @@ class OllamaAPIAgent(LocalLLMAgent):
if json_mode:
request_params["format"] = "json"
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
request_params.update(additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
response = None
try:
@@ -1580,7 +1580,7 @@ class OllamaAPIAgent(LocalLLMAgent):
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return f"Failed to communicate with the API! {err}"
_LOGGER.debug(result)
return self._extract_response(result)