From d336aa45f505f9fa605ecd8a366583e42d24f4e1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 7 Oct 2025 17:50:13 +1100 Subject: [PATCH] feat(mm): more flexible config matching utils --- .../model_manager/configs/identification_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/model_manager/configs/identification_utils.py b/invokeai/backend/model_manager/configs/identification_utils.py index c2c310a332..ce7d2c792d 100644 --- a/invokeai/backend/model_manager/configs/identification_utils.py +++ b/invokeai/backend/model_manager/configs/identification_utils.py @@ -54,7 +54,7 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]: raise NotAMatchError(f"unable to load config file(s): {problems}") -def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str: +def get_class_name_from_config_dict_or_raise(config: Path | set[Path] | dict[str, Any]) -> str: """Load the diffusers/transformers model config file and return the class name. Args: @@ -67,7 +67,8 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s NotAMatch if the config file is missing or does not contain a valid class name. """ - config = get_config_dict_or_raise(config_path) + if not isinstance(config, dict): + config = get_config_dict_or_raise(config) try: if "_class_name" in config: @@ -79,7 +80,7 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s else: raise ValueError("missing _class_name or architectures field") except Exception as e: - raise NotAMatchError(f"unable to determine class name from config file: {config_path}") from e + raise NotAMatchError(f"unable to determine class name from config file: {config}") from e if not isinstance(config_class_name, str): raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}") @@ -87,7 +88,7 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s return config_class_name -def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[str]) -> None: +def raise_for_class_name(config: Path | set[Path] | dict[str, Any], class_name: str | set[str]) -> None: """Get the class name from the config file and raise NotAMatch if it is not in the expected set. Args: @@ -100,7 +101,7 @@ def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[st class_name = {class_name} if isinstance(class_name, str) else class_name - actual_class_name = get_class_name_from_config_dict_or_raise(config_path) + actual_class_name = get_class_name_from_config_dict_or_raise(config) if actual_class_name not in class_name: raise NotAMatchError(f"invalid class name from config: {actual_class_name}")