mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tidy(mm): clean up model heuristic utils
This commit is contained in:
@@ -106,35 +106,67 @@ class NotAMatch(Exception):
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
    """Find the Pydantic core schema for a specific field in a model.

    Module-level spelling kept for existing callers; the lookup itself lives in
    FieldValidator.find_field_schema so there is a single implementation.
    """
    return FieldValidator.find_field_schema(model, field_name)


class FieldValidator:
    """Utility class for validating individual fields of a Pydantic model without instantiating the whole model.

    All members are static; the class is only a namespace and is never instantiated here.

    See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
    """

    @staticmethod
    def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
        """Find the Pydantic core schema for a specific field in a model.

        Args:
            model: The Pydantic model class whose `__pydantic_core_schema__` is searched.
            field_name: Name of the field to look up in the schema's "fields" mapping.

        Returns:
            The field's schema. When the model schema is of type "definitions", a shallow copy of that
            schema is returned with its "schema" entry replaced by the field's schema; otherwise the
            field's schema is returned directly.

        Raises:
            AssertionError: If the model schema type is neither "definitions" nor "model".
            KeyError: If `field_name` is not present in the schema's "fields" mapping.
        """
        schema: CoreSchema = model.__pydantic_core_schema__.copy()
        # we shallow copied, be careful not to mutate the original schema!

        assert schema["type"] in ["definitions", "model"]

        # find the field schema: descend through nested "schema" entries until one exposes "fields"
        field_schema = schema["schema"]  # type: ignore
        while "fields" not in field_schema:
            field_schema = field_schema["schema"]  # type: ignore

        field_schema = field_schema["fields"][field_name]["schema"]  # type: ignore

        # if the original schema is a definition schema, replace the model schema with the field schema
        # (only the top-level key of the shallow copy is rebound, the original schema is left untouched)
        if schema["type"] == "definitions":
            schema["schema"] = field_schema
            return schema
        else:
            return field_schema

    # `staticmethod` must be the outermost decorator: the cache then wraps the plain function and the
    # class attribute stays a real static method. With the order reversed, the attribute is the cache
    # wrapper itself, which only behaves when accessed through the class and depends on staticmethod
    # objects being directly callable.
    @staticmethod
    @cache
    def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
        """Get a SchemaValidator for a specific field in a model.

        The result is memoized by `cache`, keyed on the `(model, field_name)` arguments.
        """
        return SchemaValidator(FieldValidator.find_field_schema(model, field_name))

    @staticmethod
    def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
        """Validate a value for a specific field in a model.

        Returns whatever the field validator's `validate_python` returns for `value`; validation
        errors raised by it propagate to the caller.
        """
        return FieldValidator.get_validator(model, field_name).validate_python(value)
|
||||
|
||||
|
||||
@cache
def validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
    """Build a SchemaValidator for a single field of `model`; `cache` memoizes the result per argument pair."""
    field_schema = find_field_schema(model, field_name)
    return SchemaValidator(field_schema)
|
||||
def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has all of the specified keys."""
|
||||
_keys = {keys} if isinstance(keys, str) else keys
|
||||
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
|
||||
|
||||
|
||||
def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
    """Validate `value` against one field of `model` and return the result of `validate_python`."""
    field_validator = validator(model, field_name)
    return field_validator.validate_python(value)
|
||||
def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
|
||||
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
|
||||
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
|
||||
|
||||
|
||||
def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
|
||||
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
|
||||
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
|
||||
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
|
||||
|
||||
|
||||
def common_config_paths(path: Path) -> set[Path]:
    """Return the config file locations commonly found directly inside a model directory.

    The result holds `config.json` and `model_index.json` joined onto `path`; the files are
    not checked for existence.
    """
    filenames = ("config.json", "model_index.json")
    return {path / filename for filename in filenames}
|
||||
|
||||
|
||||
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
|
||||
@@ -225,7 +257,7 @@ def _validate_override_fields(
|
||||
if field_name not in config_class.model_fields:
|
||||
raise NotAMatch(config_class, f"unknown override field: {field_name}")
|
||||
try:
|
||||
validate_model_field(config_class, field_name, override_value)
|
||||
FieldValidator.validate_field(config_class, field_name, override_value)
|
||||
except ValidationError as e:
|
||||
raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e
|
||||
|
||||
@@ -440,7 +472,13 @@ class T5EncoderConfig(ModelConfigBase):
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"T5EncoderModel",
|
||||
},
|
||||
)
|
||||
|
||||
cls._validate_has_unquantized_config_file(mod)
|
||||
|
||||
@@ -465,7 +503,13 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"T5EncoderModel",
|
||||
},
|
||||
)
|
||||
|
||||
cls._validate_filename_looks_like_bnb_quantized(mod)
|
||||
|
||||
@@ -481,7 +525,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
|
||||
has_scb_key_suffix = mod.has_keys_ending_with("SCB")
|
||||
has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB")
|
||||
if not has_scb_key_suffix:
|
||||
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
|
||||
|
||||
@@ -592,23 +636,25 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
|
||||
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
|
||||
# Some main models have these keys, likely due to the creator merging in a LoRA.
|
||||
has_key_with_lora_prefix = mod.has_keys_starting_with(
|
||||
has_key_with_lora_prefix = has_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"lora_te_",
|
||||
"lora_unet_",
|
||||
"lora_te1_",
|
||||
"lora_te2_",
|
||||
"lora_transformer_",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
has_key_with_lora_suffix = mod.has_keys_ending_with(
|
||||
has_key_with_lora_suffix = has_keys_ending_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"to_k_lora.up.weight",
|
||||
"to_q_lora.down.weight",
|
||||
"lora_A.weight",
|
||||
"lora_B.weight",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
|
||||
@@ -754,7 +800,13 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
if not has_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"encoder.conv_in",
|
||||
"decoder.conv_in",
|
||||
},
|
||||
):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
|
||||
@classmethod
|
||||
@@ -786,7 +838,14 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"AutoencoderKL",
|
||||
"AutoencoderTiny",
|
||||
},
|
||||
)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
@@ -812,7 +871,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
|
||||
config = _get_config_or_raise(cls, mod.common_config_paths())
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
if cls._config_looks_like_sdxl(config):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif cls._name_looks_like_sdxl(mod):
|
||||
@@ -843,7 +902,14 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"ControlNetModel",
|
||||
"FluxControlNetModel",
|
||||
},
|
||||
)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
@@ -851,7 +917,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases:
|
||||
config = _get_config_or_raise(cls, mod.common_config_paths())
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
|
||||
if config.get("_class_name") == "FluxControlNetModel":
|
||||
return BaseModelType.Flux
|
||||
@@ -900,7 +966,8 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_starting_with(
|
||||
if has_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"controlnet",
|
||||
"control_model",
|
||||
@@ -911,7 +978,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
||||
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
|
||||
# delicate.
|
||||
"controlnet_blocks",
|
||||
}
|
||||
},
|
||||
):
|
||||
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
|
||||
|
||||
@@ -1268,7 +1335,8 @@ class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, Mo
|
||||
|
||||
@classmethod
|
||||
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_exact(
|
||||
if not has_keys_exact(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
@@ -1426,7 +1494,7 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
# SD 1.x and 2.x
|
||||
"StableDiffusionPipeline",
|
||||
@@ -1527,7 +1595,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase)
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"StableDiffusion3Pipeline",
|
||||
"SD3Transformer2DModel",
|
||||
@@ -1548,7 +1616,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase)
|
||||
@classmethod
|
||||
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
|
||||
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
|
||||
config = _get_config_or_raise(cls, mod.common_config_paths())
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
|
||||
submodels: dict[SubModelType, SubmodelDefinition] = {}
|
||||
|
||||
@@ -1601,8 +1669,10 @@ class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigB
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
{"CogView4Pipeline"},
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"CogView4Pipeline",
|
||||
},
|
||||
)
|
||||
|
||||
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
|
||||
@@ -1706,13 +1776,14 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_starting_with(
|
||||
if not has_keys_starting_with(
|
||||
mod.load_state_dict(),
|
||||
{
|
||||
"image_proj.",
|
||||
"ip_adapter.",
|
||||
# XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.".
|
||||
"ip_adapter_proj_model.",
|
||||
}
|
||||
},
|
||||
):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
|
||||
|
||||
@@ -1778,7 +1849,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
@@ -1792,7 +1863,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
|
||||
config = _get_config_or_raise(cls, mod.common_config_paths())
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
clip_variant = _get_clip_variant_type_from_config(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.G:
|
||||
@@ -1816,7 +1887,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
@@ -1830,7 +1901,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
|
||||
config = _get_config_or_raise(cls, mod.common_config_paths())
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
clip_variant = _get_clip_variant_type_from_config(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.L:
|
||||
@@ -1852,7 +1923,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"CLIPVisionModelWithProjection",
|
||||
},
|
||||
@@ -1882,7 +1953,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"T2IAdapter",
|
||||
},
|
||||
@@ -1894,7 +1965,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
|
||||
config = _get_config_or_raise(cls, mod.common_config_paths())
|
||||
config = _get_config_or_raise(cls, common_config_paths(mod.path))
|
||||
|
||||
adapter_type = config.get("adapter_type")
|
||||
|
||||
@@ -1955,7 +2026,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"SiglipModel",
|
||||
},
|
||||
@@ -1998,7 +2069,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.common_config_paths(),
|
||||
common_config_paths(mod.path),
|
||||
{
|
||||
"LlavaOnevisionForConditionalGeneration",
|
||||
},
|
||||
|
||||
@@ -128,26 +128,3 @@ class ModelOnDisk:
|
||||
f"Please specify the intended file using the 'path' argument"
|
||||
)
|
||||
return path
|
||||
|
||||
def has_keys_exact(self, keys: str | set[str], path: Optional[Path] = None) -> bool:
|
||||
_keys = {keys} if isinstance(keys, str) else keys
|
||||
state_dict = self.load_state_dict(path)
|
||||
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
|
||||
|
||||
def has_keys_starting_with(self, prefixes: str | set[str], path: Optional[Path] = None) -> bool:
|
||||
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
|
||||
state_dict = self.load_state_dict(path)
|
||||
return any(
|
||||
any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)
|
||||
)
|
||||
|
||||
def has_keys_ending_with(self, suffixes: str | set[str], path: Optional[Path] = None) -> bool:
|
||||
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
|
||||
state_dict = self.load_state_dict(path)
|
||||
return any(
|
||||
any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)
|
||||
)
|
||||
|
||||
def common_config_paths(self) -> set[Path]:
    """Return the config file locations commonly found directly inside this model's directory.

    The result holds `config.json` and `model_index.json` joined onto `self.path`.
    """
    root = self.path
    return {root / filename for filename in ("config.json", "model_index.json")}
|
||||
|
||||
Reference in New Issue
Block a user