feat(mm): make match helpers more succint

This commit is contained in:
psychedelicious
2025-09-24 18:50:06 +10:00
parent d89472d3b1
commit 3b606b6d63

View File

@@ -170,6 +170,24 @@ def _validate_overrides(
return is_perfect_match
def _raise_if_not_file(
config_class: type,
mod: ModelOnDisk,
) -> None:
"""Raise NotAMatch if the model path is not a file."""
if not mod.path.is_file():
raise NotAMatch(config_class, "model path is not a file")
def _raise_if_not_dir(
config_class: type,
mod: ModelOnDisk,
) -> None:
"""Raise NotAMatch if the model path is not a directory."""
if not mod.path.is_dir():
raise NotAMatch(config_class, "model path is not a directory")
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -392,21 +410,12 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
_validate_class_names(
config_class=cls,
config_path=mod.path / "text_encoder_2" / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
# Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
@@ -431,19 +440,13 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
# Heuristic: Look for the T5EncoderModel class name in the config
_validate_class_names(
config_class=cls,
config_path=mod.path / "text_encoder_2" / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
# Heuristic: look for the quantization in the filename name
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
@@ -467,16 +470,11 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# OMI LoRAs are always files
_raise_if_not_file(cls, mod)
# Heuristic: OMI LoRAs are always files, never directories
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
@@ -522,16 +520,11 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# LyCORIS LoRAs are always files, never directories
_raise_if_not_file(cls, mod)
# Heuristic: LyCORIS LoRAs are always files, never directories
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
@@ -598,16 +591,11 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# Diffusers-style models always directories
_raise_if_not_dir(cls, mod)
# Heuristic: Diffusers LoRAs are always directories, never files
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
@@ -637,15 +625,10 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_file(cls, mod)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
@@ -684,21 +667,12 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
_validate_class_names(
config_class=cls,
config_path=mod.path / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@@ -828,15 +802,10 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_file(cls, mod)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
if not cls._file_looks_like_embedding(mod):
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
@@ -861,15 +830,10 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_dir(cls, mod)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
for p in mod.weight_files():
if cls._file_looks_like_embedding(mod, p):
@@ -991,23 +955,14 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_dir(cls, mod)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
config_path = mod.path / "config.json"
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
config = _get_config_or_raise(cls, config_path)
@@ -1036,23 +991,14 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_dir(cls, mod)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
config_path = mod.path / "config.json"
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
config = _get_config_or_raise(cls, config_path)
clip_variant = cls.get_clip_variant_type(config)
@@ -1080,23 +1026,14 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_dir(cls, mod)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
config_path = mod.path / "config.json"
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
return cls(**fields)
@@ -1123,15 +1060,10 @@ class SpandrelImageToImageConfig(ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_file(cls, mod)
if not mod.path.is_file():
raise NotAMatch(cls, "model path is a directory, not a file")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
@@ -1166,33 +1098,38 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_dir(cls, mod)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
config_path = mod.path / "config.json"
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
return cls(**fields)
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
class FluxReduxConfig(ModelConfigBase):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SigLIP,
"format": ModelFormat.Diffusers,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
return cls(**fields)
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
@@ -1212,23 +1149,14 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if _validate_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
_raise_if_not_dir(cls, mod)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
config_path = mod.path / "config.json"
_validate_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
return cls(**fields)