tidy(mm): skip optimistic override handling for now

This commit is contained in:
psychedelicious
2025-09-24 23:16:14 +10:00
parent b74e0f6ca4
commit 6d96fa055a

View File

@@ -144,40 +144,26 @@ def _validate_overrides(
config_class: type,
provided_overrides: dict[str, Any],
valid_overrides: dict[str, Any],
) -> bool:
) -> None:
"""Check if the provided overrides match the valid overrides for this config class.
Args:
config_class: The config class that is being tested.
provided_overrides: The overrides provided by the user.
valid_overrides: The overrides that are valid for this config class. The value can be a specific value or a
callable that takes the provided value and returns True if it is valid.
Returns:
True if all provided overrides match the valid overrides, False if some valid overrides are missing.
valid_overrides: The overrides that are valid for this config class.
Raises:
NotAMatch if any override does not match the allowed value.
"""
is_perfect_match = True
for override_name, constraint in valid_overrides.items():
if override_name not in provided_overrides:
is_perfect_match = False
for key, value in valid_overrides.items():
if key not in provided_overrides:
continue
# Handle the typical case where the constraint is a specific value
if provided_overrides[override_name] != constraint:
if provided_overrides[key] != value:
raise NotAMatch(
config_class,
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}={constraint}",
)
# Handle the less common case where the constraint is a callable
elif callable(constraint) and not constraint(provided_overrides[override_name]):
raise NotAMatch(
config_class,
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}=callable",
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
)
return is_perfect_match
def _raise_if_not_file(
@@ -422,8 +408,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
@@ -452,8 +437,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
# Heuristic: Look for the T5EncoderModel class name in the config
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
@@ -483,8 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
# OMI LoRAs are always files
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
@@ -532,8 +515,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
@@ -603,8 +585,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
# Diffusers-style models always directories
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
@@ -636,8 +617,7 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
@@ -678,8 +658,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
@@ -813,8 +792,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
if not cls._file_looks_like_embedding(mod):
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
@@ -841,8 +819,7 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
for p in mod.weight_files():
if cls._file_looks_like_embedding(mod, p):
@@ -903,15 +880,6 @@ class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
IPAdapterInvokeAIConfigBaseTypes: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]
"""Helper TypeAlias for valid base types for IP Adapter models in the InvokeAI format."""
ip_adapter_invoke_ai_base_type_adapter = TypeAdapter[IPAdapterInvokeAIConfigBaseTypes](IPAdapterInvokeAIConfigBaseTypes)
"""Helper TypeAdapter for IP Adapter InvokeAI base types."""
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
@@ -921,20 +889,17 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
base: IPAdapterInvokeAIConfigBaseTypes = Field(...)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
"format": ModelFormat.InvokeAI,
"base": ip_adapter_invoke_ai_base_type_adapter.validate_python,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
weights_file = mod.path / "ip_adapter.bin"
if not weights_file.exists():
@@ -948,7 +913,7 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAIConfigBaseTypes:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
try:
@@ -967,18 +932,6 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
IPAdapterCheckpointConfigBaseTypes: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
"""Helper TypeAlias for valid base types for IP Adapter models in the Checkpoint format."""
ip_adapter_checkpoint_base_type_adapter = TypeAdapter[IPAdapterCheckpointConfigBaseTypes](
IPAdapterCheckpointConfigBaseTypes
)
"""Helper TypeAdapter for IP Adapter Checkpoint base types."""
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
@@ -989,15 +942,13 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
"format": ModelFormat.Checkpoint,
"base": ip_adapter_checkpoint_base_type_adapter.validate_python,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
if not mod.has_keys_starting_with(
{
@@ -1013,7 +964,7 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpointConfigBaseTypes:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_ip_adapter(state_dict):
@@ -1083,8 +1034,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
config_path = mod.path / "config.json"
@@ -1119,8 +1069,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
config_path = mod.path / "config.json"
@@ -1154,8 +1103,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
config_path = mod.path / "config.json"
@@ -1188,8 +1136,7 @@ class SpandrelImageToImageConfig(ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
@@ -1226,8 +1173,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
config_path = mod.path / "config.json"
@@ -1251,8 +1197,7 @@ class FluxReduxConfig(ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
@@ -1280,8 +1225,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_dir(cls, mod)
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
config_path = mod.path / "config.json"