feat(mm): more flexible config matching utils

This commit is contained in:
psychedelicious
2025-10-07 17:50:13 +11:00
parent 303acdb4ac
commit d336aa45f5

View File

@@ -54,7 +54,7 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
raise NotAMatchError(f"unable to load config file(s): {problems}")
def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str:
def get_class_name_from_config_dict_or_raise(config: Path | set[Path] | dict[str, Any]) -> str:
"""Load the diffusers/transformers model config file and return the class name.
Args:
@@ -67,7 +67,8 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
NotAMatch if the config file is missing or does not contain a valid class name.
"""
config = get_config_dict_or_raise(config_path)
if not isinstance(config, dict):
config = get_config_dict_or_raise(config)
try:
if "_class_name" in config:
@@ -79,7 +80,7 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
else:
raise ValueError("missing _class_name or architectures field")
except Exception as e:
raise NotAMatchError(f"unable to determine class name from config file: {config_path}") from e
raise NotAMatchError(f"unable to determine class name from config file: {config}") from e
if not isinstance(config_class_name, str):
raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}")
@@ -87,7 +88,7 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
return config_class_name
def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[str]) -> None:
def raise_for_class_name(config: Path | set[Path] | dict[str, Any], class_name: str | set[str]) -> None:
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
Args:
@@ -100,7 +101,7 @@ def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[st
class_name = {class_name} if isinstance(class_name, str) else class_name
actual_class_name = get_class_name_from_config_dict_or_raise(config_path)
actual_class_name = get_class_name_from_config_dict_or_raise(config)
if actual_class_name not in class_name:
raise NotAMatchError(f"invalid class name from config: {actual_class_name}")