mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 21:58:00 -05:00)
populate potential models if the provided one was unavailable
@@ -170,7 +170,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
         }
     )
 
-def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None):
+def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None, available_chat_models=[]):
 
     extra1, extra2 = ({}, {})
     default_port = DEFAULT_PORT
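A side note on the new signature: `available_chat_models=[]` is a mutable default, which Python evaluates once at definition time and shares across calls. That is harmless here as long as the schema builder only reads the list, but the pitfall is worth illustrating (a standalone sketch, not code from this repository):

```python
def remember(item, bucket=[]):
    """Appends into the single list created when the function was defined."""
    bucket.append(item)
    return bucket

assert remember("a") == ["a"]
assert remember("b") == ["a", "b"]  # state leaked between calls

def remember_safely(item, bucket=None):
    """The conventional fix: default to None and allocate per call."""
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

assert remember_safely("b") == ["b"]  # fresh list every call
```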
@@ -188,7 +188,12 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss
             vol.Required(CONF_HOST, default=host if host else ""): str,
             vol.Required(CONF_PORT, default=port if port else default_port): str,
             vol.Required(CONF_SSL, default=ssl if ssl else DEFAULT_SSL): bool,
-            vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
+            vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
+                options=available_chat_models,
+                custom_value=True,
+                multiple=False,
+                mode=SelectSelectorMode.DROPDOWN,
+            )),
             **extra1,
             vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")),
             **extra2
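The hunk above swaps the free-text `str` field for Home Assistant's `SelectSelector`, configured so the dropdown lists the models reported by the backend while `custom_value=True` still accepts a name that is not in the list. A minimal self-contained sketch of the same pattern (the helper name `chat_model_schema` and the plain `"chat_model"` key are illustrative; the selector classes are the real `homeassistant.helpers.selector` API used in the diff):

```python
import voluptuous as vol
from homeassistant.helpers.selector import (
    SelectSelector,
    SelectSelectorConfig,
    SelectSelectorMode,
)

def chat_model_schema(available_chat_models: list[str], default: str) -> vol.Schema:
    """Dropdown of known models that still accepts free-form input."""
    return vol.Schema({
        vol.Required("chat_model", default=default): SelectSelector(
            SelectSelectorConfig(
                options=available_chat_models,  # e.g. names the server reported
                custom_value=True,              # typing a new name is still allowed
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,
            )
        ),
    })
```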
@@ -297,7 +302,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
         self.install_wheel_error = None
 
         return self.async_show_form(
-            step_id="pick_backend", data_schema=schema, errors=errors
+            step_id="pick_backend", data_schema=schema, errors=errors, last_step=False
         )
 
     async def async_step_install_local_wheels(
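`last_step` is a standard `async_show_form` argument: set to `False`, it marks the form as an intermediate step so the frontend labels the confirm button "Next" rather than "Submit". A standalone stub that only mimics the relevant behavior (the real `FlowHandler` lives in `homeassistant.data_entry_flow`):

```python
class StubFlow:
    """Toy stand-in for a config-flow handler, for illustration only."""

    def async_show_form(self, *, step_id, data_schema=None, errors=None,
                        last_step=None):
        # HA's frontend reads last_step to label the button:
        # True -> "Submit", False -> "Next", None -> frontend decides.
        return {"type": "form", "step_id": step_id, "last_step": last_step}

result = StubFlow().async_show_form(step_id="pick_backend", last_step=False)
assert result["last_step"] is False  # the UI will render "Next"
```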
@@ -377,7 +382,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
         schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
 
         return self.async_show_form(
-            step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders,
+            step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
         )
 
     async def async_step_download(
@@ -417,7 +422,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
         self.download_task = None
         return self.async_show_progress_done(next_step_id=next_step)
 
-    def _validate_text_generation_webui(self, user_input: dict) -> str:
+    def _validate_text_generation_webui(self, user_input: dict) -> tuple:
         """
         Validates a connection to text-generation-webui and that the model exists on the remote server
 
@@ -441,15 +446,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
 
             for model in models["model_names"]:
                 if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"):
-                    return None, None
+                    return None, None, []
 
-            return "missing_model_api", None
+            return "missing_model_api", None, models["model_names"]
 
         except Exception as ex:
             _LOGGER.info("Connection error was: %s", repr(ex))
-            return "failed_to_connect", ex
+            return "failed_to_connect", ex, []
 
-    def _validate_ollama(self, user_input: dict) -> str:
+    def _validate_ollama(self, user_input: dict) -> tuple:
         """
         Validates a connection to ollama and that the model exists on the remote server
 
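Both validators now return a third element: the list of models the server actually reported, which is what feeds `available_chat_models` when validation fails. A standalone sketch of the text-generation-webui side of that contract (the endpoint path `/v1/internal/model/list` is an assumption; the diff only confirms that the response carries a `model_names` key):

```python
import requests

def validate_webui(host: str, port: str, ssl: bool, chat_model: str) -> tuple:
    """Return (error_key, exception, available_models), mirroring the new contract."""
    base = f"{'https' if ssl else 'http'}://{host}:{port}"
    try:
        # Assumed endpoint; only the "model_names" response key is shown in the diff.
        resp = requests.get(f"{base}/v1/internal/model/list", timeout=5)
        resp.raise_for_status()
        model_names = resp.json()["model_names"]
    except Exception as ex:
        return "failed_to_connect", ex, []         # nothing to suggest if unreachable
    if chat_model.replace("/", "_") in model_names:
        return None, None, []                      # model present: no suggestions needed
    return "missing_model_api", None, model_names  # offer what the server does have
```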
@@ -475,15 +480,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
                 model_name = self.model_config[CONF_CHAT_MODEL]
                 if ":" in model_name:
                     if model["name"] == model_name:
-                        return None, None
+                        return (None, None, [])
                 elif model["name"].split(":")[0] == model_name:
-                    return None, None
+                    return (None, None, [])
 
-            return "missing_model_api", None
+            return "missing_model_api", None, [x["name"] for x in models]
 
         except Exception as ex:
             _LOGGER.info("Connection error was: %s", repr(ex))
-            return "failed_to_connect", ex
+            return "failed_to_connect", ex, []
 
     async def async_step_remote_model(
         self, user_input: dict[str, Any] | None = None
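The Ollama validator follows the same three-tuple contract, matching either the exact `name:tag` or the bare model name against the entries the server lists (each of which carries a `name` field, as the diff shows). A standalone sketch, assuming Ollama's standard `GET /api/tags` model-listing endpoint:

```python
import requests

def validate_ollama(host: str, port: str, ssl: bool, chat_model: str) -> tuple:
    """Return (error_key, exception, available_models) like the integration's helper."""
    base = f"{'https' if ssl else 'http'}://{host}:{port}"
    try:
        resp = requests.get(f"{base}/api/tags", timeout=5)
        resp.raise_for_status()
        models = resp.json()["models"]  # e.g. [{"name": "llama2:latest"}, ...]
    except Exception as ex:
        return "failed_to_connect", ex, []
    for model in models:
        if ":" in chat_model:                    # user gave an explicit tag
            if model["name"] == chat_model:
                return None, None, []
        elif model["name"].split(":")[0] == chat_model:  # bare name matches any tag
            return None, None, []
    return "missing_model_api", None, [m["name"] for m in models]
```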
@@ -500,13 +505,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
 
             # validate and load when using text-generation-webui or ollama
             if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
-                error_message, ex = await self.hass.async_add_executor_job(
+                error_message, ex, possible_models = await self.hass.async_add_executor_job(
                     self._validate_text_generation_webui, user_input
                 )
             elif backend_type == BACKEND_TYPE_OLLAMA:
-                error_message, ex = await self.hass.async_add_executor_job(
+                error_message, ex, possible_models = await self.hass.async_add_executor_job(
                     self._validate_ollama, user_input
                 )
+            else:
+                possible_models = []
 
             if error_message:
                 errors["base"] = error_message
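The added `else` branch is not cosmetic: `possible_models` is read unconditionally when the form is rebuilt below, so any backend that skips validation (the generic OpenAI or llama-cpp-python server paths) would otherwise leave the name unbound. A stripped-down illustration of the hazard, with placeholder backend names:

```python
def pick_models(backend: str) -> list[str]:
    """Mirrors the hunk's control flow; the string literals are placeholders."""
    if backend == "text-generation-webui":
        possible_models = ["model-a", "model-b"]  # stand-in for a validator result
    elif backend == "ollama":
        possible_models = ["model-c"]
    else:
        possible_models = []  # without this line: UnboundLocalError on return
    return possible_models

assert pick_models("generic-openai") == []
```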
@@ -518,6 +525,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
                     port=user_input[CONF_PORT],
                     ssl=user_input[CONF_SSL],
                     chat_model=user_input[CONF_CHAT_MODEL],
+                    available_chat_models=possible_models,
                 )
             else:
                 return await self.async_step_model_parameters()
@@ -527,7 +535,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
                 errors["base"] = "unknown"
 
         return self.async_show_form(
-            step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders,
+            step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
         )
 
     async def async_step_model_parameters(
@@ -43,7 +43,7 @@
           "download_model_from_hf": "Download model from HuggingFace",
           "use_local_backend": "Use Llama.cpp"
         },
-        "description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n6. [Ollama API](https://github.com/jmorganca/ollama/blob/main/docs/api.md)\n\nIf using Llama.cpp locally, make sure you copied the correct wheel file to the same directory as the integration.",
+        "description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n6. [Ollama API](https://github.com/jmorganca/ollama/blob/main/docs/api.md)",
         "title": "Select Backend"
       },
       "model_parameters": {
@@ -215,8 +215,8 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant
 
     # simulate incorrect settings on first try
     validate_connections_mock.side_effect = [
-        ("failed_to_connect", Exception("ConnectionError")),
-        (None, None)
+        ("failed_to_connect", Exception("ConnectionError"), []),
+        (None, None, [])
     ]
 
     result3 = await hass.config_entries.flow.async_configure(
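The fixture change tracks the new return arity: when `side_effect` is a list, each call to the mock pops the next value, and a tuple element is returned as-is (only a bare exception instance would be raised), so the stub must now hand back three-element tuples for the flow's unpacking to succeed. A minimal standalone version:

```python
from unittest.mock import Mock

validate = Mock(side_effect=[
    ("failed_to_connect", Exception("ConnectionError"), []),  # first attempt fails
    (None, None, []),                                         # retry succeeds
])

error, ex, models = validate()
assert error == "failed_to_connect" and models == []
error, ex, models = validate()
assert error is None and ex is None and models == []
```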