fix deprecated configflow behavior

This commit is contained in:
Alex O'Connell
2024-03-02 22:20:55 -05:00
parent 7b0b021b59
commit 6cc3f47096

View File

@@ -252,29 +252,23 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
schema = STEP_INIT_DATA_SCHEMA()
if user_input:
try:
local_backend = is_local_backend(user_input[CONF_BACKEND_TYPE])
self.model_config.update(user_input)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
local_backend = is_local_backend(user_input[CONF_BACKEND_TYPE])
self.model_config.update(user_input)
if local_backend:
return await self.async_step_install_local_wheels()
# this check isn't working right now
# for key, value in self.hass.data.get(DOMAIN, {}).items():
# other_backend_type = value.data.get(CONF_BACKEND_TYPE)
# if other_backend_type == BACKEND_TYPE_LLAMA_HF or \
# other_backend_type == BACKEND_TYPE_LLAMA_EXISTING:
# errors["base"] = "other_existing_local"
# schema = STEP_INIT_DATA_SCHEMA(
# backend_type=user_input[CONF_BACKEND_TYPE],
# )
# if "base" not in errors:
# return await self.async_step_install_local_wheels()
else:
if local_backend:
return await self.async_step_install_local_wheels()
# this check isn't working right now
# for key, value in self.hass.data.get(DOMAIN, {}).items():
# other_backend_type = value.data.get(CONF_BACKEND_TYPE)
# if other_backend_type == BACKEND_TYPE_LLAMA_HF or \
# other_backend_type == BACKEND_TYPE_LLAMA_EXISTING:
# errors["base"] = "other_existing_local"
# schema = STEP_INIT_DATA_SCHEMA(
# backend_type=user_input[CONF_BACKEND_TYPE],
# )
# if "base" not in errors:
# return await self.async_step_install_local_wheels()
else:
return await self.async_step_remote_model()
return await self.async_step_remote_model()
elif self.install_wheel_error:
errors["base"] = str(self.install_wheel_error)
self.install_wheel_error = None
@@ -289,6 +283,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if not user_input:
if self.install_wheel_task:
return self.async_show_progress(
progress_task=self.install_wheel_task,
step_id="install_local_wheels",
progress_action="install_local_wheels",
)
@@ -301,6 +296,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self.hass.async_create_task(self._async_do_task(self.install_wheel_task))
return self.async_show_progress(
progress_task=self.install_wheel_task,
step_id="install_local_wheels",
progress_action="install_local_wheels",
)
@@ -341,24 +337,18 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]
)
if user_input:
try:
self.model_config.update(user_input)
if user_input and "result" not in user_input:
self.model_config.update(user_input)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
if backend_type == BACKEND_TYPE_LLAMA_HF:
return await self.async_step_download()
else:
if backend_type == BACKEND_TYPE_LLAMA_HF:
return await self.async_step_download()
model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
if os.path.exists(model_file):
return await self.async_step_finish()
else:
model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
if os.path.exists(model_file):
return await self.async_step_finish()
else:
errors["base"] = "missing_model_file"
schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
errors["base"] = "missing_model_file"
schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file)
return self.async_show_form(
step_id="local_model", data_schema=schema, errors=errors
@@ -370,6 +360,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if not user_input:
if self.download_task:
return self.async_show_progress(
progress_task=self.download_task,
step_id="download",
progress_action="download",
)
@@ -385,6 +376,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
self.hass.async_create_task(self._async_do_task(self.download_task))
return self.async_show_progress(
progress_task=self.download_task,
step_id="download",
progress_action="download",
)