populate potential models if the provided one was unavailable

This commit is contained in:
Alex O'Connell
2024-04-11 00:14:07 -04:00
parent 7262a2057a
commit e176e60727
3 changed files with 27 additions and 19 deletions

View File

@@ -170,7 +170,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
}
)
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None, available_chat_models=[]):
extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT
@@ -188,7 +188,12 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss
vol.Required(CONF_HOST, default=host if host else ""): str,
vol.Required(CONF_PORT, default=port if port else default_port): str,
vol.Required(CONF_SSL, default=ssl if ssl else DEFAULT_SSL): bool,
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
options=available_chat_models,
custom_value=True,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
)),
**extra1,
vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")),
**extra2
@@ -297,7 +302,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self.install_wheel_error = None
return self.async_show_form(
step_id="pick_backend", data_schema=schema, errors=errors
step_id="pick_backend", data_schema=schema, errors=errors, last_step=False
)
async def async_step_install_local_wheels(
@@ -377,7 +382,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
return self.async_show_form(
step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders,
step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
)
async def async_step_download(
@@ -417,7 +422,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self.download_task = None
return self.async_show_progress_done(next_step_id=next_step)
def _validate_text_generation_webui(self, user_input: dict) -> str:
def _validate_text_generation_webui(self, user_input: dict) -> tuple:
"""
Validates a connection to text-generation-webui and that the model exists on the remote server
@@ -441,15 +446,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
for model in models["model_names"]:
if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"):
return None, None
return None, None, []
return "missing_model_api", None
return "missing_model_api", None, models["model_names"]
except Exception as ex:
_LOGGER.info("Connection error was: %s", repr(ex))
return "failed_to_connect", ex
return "failed_to_connect", ex, []
def _validate_ollama(self, user_input: dict) -> str:
def _validate_ollama(self, user_input: dict) -> tuple:
"""
Validates a connection to ollama and that the model exists on the remote server
@@ -475,15 +480,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
model_name = self.model_config[CONF_CHAT_MODEL]
if ":" in model_name:
if model["name"] == model_name:
return None, None
return (None, None, [])
elif model["name"].split(":")[0] == model_name:
return None, None
return (None, None, [])
return "missing_model_api", None
return "missing_model_api", None, [x["name"] for x in models]
except Exception as ex:
_LOGGER.info("Connection error was: %s", repr(ex))
return "failed_to_connect", ex
return "failed_to_connect", ex, []
async def async_step_remote_model(
self, user_input: dict[str, Any] | None = None
@@ -500,13 +505,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
# validate and load when using text-generation-webui or ollama
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
error_message, ex = await self.hass.async_add_executor_job(
error_message, ex, possible_models = await self.hass.async_add_executor_job(
self._validate_text_generation_webui, user_input
)
elif backend_type == BACKEND_TYPE_OLLAMA:
error_message, ex = await self.hass.async_add_executor_job(
error_message, ex, possible_models = await self.hass.async_add_executor_job(
self._validate_ollama, user_input
)
else:
possible_models = []
if error_message:
errors["base"] = error_message
@@ -518,6 +525,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
chat_model=user_input[CONF_CHAT_MODEL],
available_chat_models=possible_models,
)
else:
return await self.async_step_model_parameters()
@@ -527,7 +535,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
errors["base"] = "unknown"
return self.async_show_form(
step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders,
step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
)
async def async_step_model_parameters(

View File

@@ -43,7 +43,7 @@
"download_model_from_hf": "Download model from HuggingFace",
"use_local_backend": "Use Llama.cpp"
},
"description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n6. [Ollama API](https://github.com/jmorganca/ollama/blob/main/docs/api.md)\n\nIf using Llama.cpp locally, make sure you copied the correct wheel file to the same directory as the integration.",
"description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n6. [Ollama API](https://github.com/jmorganca/ollama/blob/main/docs/api.md)",
"title": "Select Backend"
},
"model_parameters": {

View File

@@ -215,8 +215,8 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant
# simulate incorrect settings on first try
validate_connections_mock.side_effect = [
("failed_to_connect", Exception("ConnectionError")),
(None, None)
("failed_to_connect", Exception("ConnectionError"), []),
(None, None, [])
]
result3 = await hass.config_entries.flow.async_configure(