tidy(mm): clarify that model id utils are private

This commit is contained in:
psychedelicious
2025-09-24 18:07:51 +10:00
parent d4823b6869
commit bbecc86d0f

View File

@@ -96,8 +96,11 @@ class NotAMatch(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
# NotAMatch exceptions as easy and consistent as possible.
def get_config_or_raise(
def _get_config_or_raise(
config_class: type,
config_path: Path,
) -> dict[str, Any]:
@@ -112,14 +115,14 @@ def get_config_or_raise(
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
def raise_for_class_names(
def _validate_class_names(
config_class: type,
config_path: Path,
valid_class_names: set[str],
) -> None:
"""Raise NotAMatch if the config file is missing or does not contain a valid class name."""
config = get_config_or_raise(config_class, config_path)
config = _get_config_or_raise(config_class, config_path)
try:
if "_class_name" in config:
@@ -135,8 +138,8 @@ def raise_for_class_names(
raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}")
def matches_overrides(
config_class: "Type[AnyModelConfig]",
def _validate_overrides(
config_class: type,
provided_overrides: dict[str, Any],
valid_overrides: dict[str, Any],
) -> bool:
@@ -389,7 +392,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -399,7 +402,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
raise_for_class_names(
_validate_class_names(
config_class=cls,
config_path=mod.path / "text_encoder_2" / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
@@ -428,7 +431,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -436,7 +439,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
return cls(**fields)
# Heuristic: Look for the T5EncoderModel class name in the config
raise_for_class_names(
_validate_class_names(
config_class=cls,
config_path=mod.path / "text_encoder_2" / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
@@ -464,7 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -519,7 +522,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -595,7 +598,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -634,7 +637,7 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -681,7 +684,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -691,7 +694,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
raise_for_class_names(
_validate_class_names(
config_class=cls,
config_path=mod.path / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
@@ -721,7 +724,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config = get_config_or_raise(cls, mod.path / "config.json")
config = _get_config_or_raise(cls, mod.path / "config.json")
if cls._config_looks_like_sdxl(config):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
@@ -825,7 +828,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -858,7 +861,7 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -988,7 +991,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -1000,13 +1003,13 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
config_path = mod.path / "config.json"
raise_for_class_names(
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
config = get_config_or_raise(cls, config_path)
config = _get_config_or_raise(cls, config_path)
clip_variant = cls.get_clip_variant_type(config)
@@ -1033,7 +1036,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -1045,13 +1048,13 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
config_path = mod.path / "config.json"
raise_for_class_names(
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
config = get_config_or_raise(cls, config_path)
config = _get_config_or_raise(cls, config_path)
clip_variant = cls.get_clip_variant_type(config)
if clip_variant is not ClipVariantType.L:
@@ -1089,7 +1092,7 @@ class SpandrelImageToImageConfig(ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -1147,7 +1150,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
@@ -1159,7 +1162,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
config_path = mod.path / "config.json"
raise_for_class_names(
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,