mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): more flexible config matching utils
This commit is contained in:
@@ -54,7 +54,7 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
|
||||
raise NotAMatchError(f"unable to load config file(s): {problems}")
|
||||
|
||||
|
||||
def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str:
|
||||
def get_class_name_from_config_dict_or_raise(config: Path | set[Path] | dict[str, Any]) -> str:
|
||||
"""Load the diffusers/transformers model config file and return the class name.
|
||||
|
||||
Args:
|
||||
@@ -67,7 +67,8 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
|
||||
NotAMatch if the config file is missing or does not contain a valid class name.
|
||||
"""
|
||||
|
||||
config = get_config_dict_or_raise(config_path)
|
||||
if not isinstance(config, dict):
|
||||
config = get_config_dict_or_raise(config)
|
||||
|
||||
try:
|
||||
if "_class_name" in config:
|
||||
@@ -79,7 +80,7 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
|
||||
else:
|
||||
raise ValueError("missing _class_name or architectures field")
|
||||
except Exception as e:
|
||||
raise NotAMatchError(f"unable to determine class name from config file: {config_path}") from e
|
||||
raise NotAMatchError(f"unable to determine class name from config file: {config}") from e
|
||||
|
||||
if not isinstance(config_class_name, str):
|
||||
raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}")
|
||||
@@ -87,7 +88,7 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
|
||||
return config_class_name
|
||||
|
||||
|
||||
def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[str]) -> None:
|
||||
def raise_for_class_name(config: Path | set[Path] | dict[str, Any], class_name: str | set[str]) -> None:
|
||||
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
|
||||
|
||||
Args:
|
||||
@@ -100,7 +101,7 @@ def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[st
|
||||
|
||||
class_name = {class_name} if isinstance(class_name, str) else class_name
|
||||
|
||||
actual_class_name = get_class_name_from_config_dict_or_raise(config_path)
|
||||
actual_class_name = get_class_name_from_config_dict_or_raise(config)
|
||||
if actual_class_name not in class_name:
|
||||
raise NotAMatchError(f"invalid class name from config: {actual_class_name}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user