mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): make match helpers more succint
This commit is contained in:
@@ -170,6 +170,24 @@ def _validate_overrides(
|
||||
return is_perfect_match
|
||||
|
||||
|
||||
def _raise_if_not_file(
|
||||
config_class: type,
|
||||
mod: ModelOnDisk,
|
||||
) -> None:
|
||||
"""Raise NotAMatch if the model path is not a file."""
|
||||
if not mod.path.is_file():
|
||||
raise NotAMatch(config_class, "model path is not a file")
|
||||
|
||||
|
||||
def _raise_if_not_dir(
|
||||
config_class: type,
|
||||
mod: ModelOnDisk,
|
||||
) -> None:
|
||||
"""Raise NotAMatch if the model path is not a directory."""
|
||||
if not mod.path.is_dir():
|
||||
raise NotAMatch(config_class, "model path is not a directory")
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
@@ -392,21 +410,12 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=mod.path / "text_encoder_2" / "config.json",
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
|
||||
|
||||
# Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
|
||||
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
|
||||
@@ -431,19 +440,13 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: Look for the T5EncoderModel class name in the config
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=mod.path / "text_encoder_2" / "config.json",
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
|
||||
|
||||
# Heuristic: look for the quantization in the filename name
|
||||
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
|
||||
@@ -467,16 +470,11 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
# OMI LoRAs are always files
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
# Heuristic: OMI LoRAs are always files, never directories
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
@@ -522,16 +520,11 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
# LyCORIS LoRAs are always files, never directories
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
# Heuristic: LyCORIS LoRAs are always files, never directories
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
@@ -598,16 +591,11 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
# Diffusers-style models always directories
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
# Heuristic: Diffusers LoRAs are always directories, never files
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
|
||||
@@ -637,15 +625,10 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
@@ -684,21 +667,12 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=mod.path / "config.json",
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
@@ -828,15 +802,10 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
if not cls._file_looks_like_embedding(mod):
|
||||
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
|
||||
@@ -861,15 +830,10 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
for p in mod.weight_files():
|
||||
if cls._file_looks_like_embedding(mod, p):
|
||||
@@ -991,23 +955,14 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
config = _get_config_or_raise(cls, config_path)
|
||||
|
||||
@@ -1036,23 +991,14 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
config = _get_config_or_raise(cls, config_path)
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
@@ -1080,23 +1026,14 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@@ -1123,15 +1060,10 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if not mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
@@ -1166,33 +1098,38 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
class FluxReduxConfig(ModelConfigBase):
|
||||
"""Model config for FLUX Tools Redux model."""
|
||||
|
||||
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.SigLIP,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
"""Model config for Llava Onevision models."""
|
||||
@@ -1212,23 +1149,14 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if _validate_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user