diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index aca724d..87acd50 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -170,7 +170,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q } ) -def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None): +def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None, available_chat_models=[]): extra1, extra2 = ({}, {}) default_port = DEFAULT_PORT @@ -188,7 +188,12 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss vol.Required(CONF_HOST, default=host if host else ""): str, vol.Required(CONF_PORT, default=port if port else default_port): str, vol.Required(CONF_SSL, default=ssl if ssl else DEFAULT_SSL): bool, - vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str, + vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig( + options=available_chat_models, + custom_value=True, + multiple=False, + mode=SelectSelectorMode.DROPDOWN, + )), **extra1, vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")), **extra2 @@ -297,7 +302,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom self.install_wheel_error = None return self.async_show_form( - step_id="pick_backend", data_schema=schema, errors=errors + step_id="pick_backend", data_schema=schema, errors=errors, last_step=False ) async def async_step_install_local_wheels( @@ -377,7 +382,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file) return self.async_show_form( - step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, + step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False ) async def async_step_download( @@ -417,7 +422,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom self.download_task = None return self.async_show_progress_done(next_step_id=next_step) - def _validate_text_generation_webui(self, user_input: dict) -> str: + def _validate_text_generation_webui(self, user_input: dict) -> tuple: """ Validates a connection to text-generation-webui and that the model exists on the remote server @@ -441,15 +446,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom for model in models["model_names"]: if model == self.model_config[CONF_CHAT_MODEL].replace("/", "_"): - return None, None + return None, None, [] - return "missing_model_api", None + return "missing_model_api", None, models["model_names"] except Exception as ex: _LOGGER.info("Connection error was: %s", repr(ex)) - return "failed_to_connect", ex + return "failed_to_connect", ex, [] - def _validate_ollama(self, user_input: dict) -> str: + def _validate_ollama(self, user_input: dict) -> tuple: """ Validates a connection to ollama and that the model exists on the remote server @@ -475,15 +480,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom model_name = self.model_config[CONF_CHAT_MODEL] if ":" in model_name: if model["name"] == model_name: - return None, None + return (None, None, []) elif model["name"].split(":")[0] == model_name: - return None, None + return (None, None, []) - return "missing_model_api", None + return "missing_model_api", None, [x["name"] for x in models] except Exception as ex: _LOGGER.info("Connection error was: %s", repr(ex)) - return "failed_to_connect", ex + return "failed_to_connect", ex, [] async def async_step_remote_model( self, user_input: dict[str, Any] | None = None @@ -500,13 +505,15 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom # validate and load when using text-generation-webui or ollama if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI: - error_message, ex = await self.hass.async_add_executor_job( + error_message, ex, possible_models = await self.hass.async_add_executor_job( self._validate_text_generation_webui, user_input ) elif backend_type == BACKEND_TYPE_OLLAMA: - error_message, ex = await self.hass.async_add_executor_job( + error_message, ex, possible_models = await self.hass.async_add_executor_job( self._validate_ollama, user_input ) + else: + possible_models = [] if error_message: errors["base"] = error_message @@ -518,6 +525,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom port=user_input[CONF_PORT], ssl=user_input[CONF_SSL], chat_model=user_input[CONF_CHAT_MODEL], + available_chat_models=possible_models, ) else: return await self.async_step_model_parameters() @@ -527,7 +535,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom errors["base"] = "unknown" return self.async_show_form( - step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, + step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False ) async def async_step_model_parameters( diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index bce5d06..8a10af2 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -43,7 +43,7 @@ "download_model_from_hf": "Download model from HuggingFace", "use_local_backend": "Use Llama.cpp" }, - "description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n6. [Ollama API](https://github.com/jmorganca/ollama/blob/main/docs/api.md)\n\nIf using Llama.cpp locally, make sure you copied the correct wheel file to the same directory as the integration.", + "description": "Select the backend for running the model. The options are:\n1. Llama.cpp with a model from HuggingFace\n2. Llama.cpp with a model stored on the disk\n3. [text-generation-webui API](https://github.com/oobabooga/text-generation-webui)\n4. Generic OpenAI API Compatible API\n5. [llama-cpp-python Server](https://llama-cpp-python.readthedocs.io/en/latest/server/)\n6. [Ollama API](https://github.com/jmorganca/ollama/blob/main/docs/api.md)", "title": "Select Backend" }, "model_parameters": { diff --git a/tests/llama_conversation/test_config_flow.py b/tests/llama_conversation/test_config_flow.py index ab8e8ff..a28a0db 100644 --- a/tests/llama_conversation/test_config_flow.py +++ b/tests/llama_conversation/test_config_flow.py @@ -215,8 +215,8 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant # simulate incorrect settings on first try validate_connections_mock.side_effect = [ - ("failed_to_connect", Exception("ConnectionError")), - (None, None) + ("failed_to_connect", Exception("ConnectionError"), []), + (None, None, []) ] result3 = await hass.config_entries.flow.async_configure(