mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tidy(mm): skip optimistic override handling for now
This commit is contained in:
@@ -144,40 +144,26 @@ def _validate_overrides(
|
||||
config_class: type,
|
||||
provided_overrides: dict[str, Any],
|
||||
valid_overrides: dict[str, Any],
|
||||
) -> bool:
|
||||
) -> None:
|
||||
"""Check if the provided overrides match the valid overrides for this config class.
|
||||
|
||||
Args:
|
||||
config_class: The config class that is being tested.
|
||||
provided_overrides: The overrides provided by the user.
|
||||
valid_overrides: The overrides that are valid for this config class. The value can be a specific value or a
|
||||
callable that takes the provided value and returns True if it is valid.
|
||||
|
||||
Returns:
|
||||
True if all provided overrides match the valid overrides, False if some valid overrides are missing.
|
||||
valid_overrides: The overrides that are valid for this config class.
|
||||
|
||||
Raises:
|
||||
NotAMatch if any override does not match the allowed value.
|
||||
"""
|
||||
is_perfect_match = True
|
||||
for override_name, constraint in valid_overrides.items():
|
||||
if override_name not in provided_overrides:
|
||||
is_perfect_match = False
|
||||
for key, value in valid_overrides.items():
|
||||
if key not in provided_overrides:
|
||||
continue
|
||||
# Handle the typical case where the constraint is a specific value
|
||||
if provided_overrides[override_name] != constraint:
|
||||
if provided_overrides[key] != value:
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}={constraint}",
|
||||
)
|
||||
# Handle the less common case where the constraint is a callable
|
||||
elif callable(constraint) and not constraint(provided_overrides[override_name]):
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {override_name}={provided_overrides[override_name]} does not match required value {override_name}=callable",
|
||||
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
|
||||
)
|
||||
|
||||
return is_perfect_match
|
||||
|
||||
|
||||
def _raise_if_not_file(
|
||||
@@ -422,8 +408,7 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
|
||||
|
||||
@@ -452,8 +437,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
# Heuristic: Look for the T5EncoderModel class name in the config
|
||||
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
|
||||
@@ -483,8 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
# OMI LoRAs are always files
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
@@ -532,8 +515,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
@@ -603,8 +585,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
# Diffusers-style models always directories
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
|
||||
@@ -636,8 +617,7 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
@@ -678,8 +658,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
|
||||
|
||||
@@ -813,8 +792,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
if not cls._file_looks_like_embedding(mod):
|
||||
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
|
||||
@@ -841,8 +819,7 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
for p in mod.weight_files():
|
||||
if cls._file_looks_like_embedding(mod, p):
|
||||
@@ -903,15 +880,6 @@ class IPAdapterConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
|
||||
|
||||
IPAdapterInvokeAIConfigBaseTypes: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
]
|
||||
"""Helper TypeAlias for valid base types for IP Adapter models in the InvokeAI format."""
|
||||
|
||||
ip_adapter_invoke_ai_base_type_adapter = TypeAdapter[IPAdapterInvokeAIConfigBaseTypes](IPAdapterInvokeAIConfigBaseTypes)
|
||||
"""Helper TypeAdapter for IP Adapter InvokeAI base types."""
|
||||
|
||||
|
||||
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
@@ -921,20 +889,17 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
# time. Need to go through the history to make sure I'm understanding this fully.
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
|
||||
base: IPAdapterInvokeAIConfigBaseTypes = Field(...)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.IPAdapter,
|
||||
"format": ModelFormat.InvokeAI,
|
||||
"base": ip_adapter_invoke_ai_base_type_adapter.validate_python,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
weights_file = mod.path / "ip_adapter.bin"
|
||||
if not weights_file.exists():
|
||||
@@ -948,7 +913,7 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAIConfigBaseTypes:
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
try:
|
||||
@@ -967,18 +932,6 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
||||
|
||||
|
||||
IPAdapterCheckpointConfigBaseTypes: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.Flux,
|
||||
]
|
||||
"""Helper TypeAlias for valid base types for IP Adapter models in the Checkpoint format."""
|
||||
|
||||
ip_adapter_checkpoint_base_type_adapter = TypeAdapter[IPAdapterCheckpointConfigBaseTypes](
|
||||
IPAdapterCheckpointConfigBaseTypes
|
||||
)
|
||||
"""Helper TypeAdapter for IP Adapter Checkpoint base types."""
|
||||
|
||||
|
||||
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
@@ -989,15 +942,13 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.IPAdapter,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
"base": ip_adapter_checkpoint_base_type_adapter.validate_python,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
if not mod.has_keys_starting_with(
|
||||
{
|
||||
@@ -1013,7 +964,7 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpointConfigBaseTypes:
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
if is_state_dict_xlabs_ip_adapter(state_dict):
|
||||
@@ -1083,8 +1034,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
@@ -1119,8 +1069,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
@@ -1154,8 +1103,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
@@ -1188,8 +1136,7 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
@@ -1226,8 +1173,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
@@ -1251,8 +1197,7 @@ class FluxReduxConfig(ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
|
||||
raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
|
||||
@@ -1280,8 +1225,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user