tidy(mm): clean up model heuristic utils

This commit is contained in:
psychedelicious
2025-10-01 12:21:19 +10:00
parent 951635fbee
commit c53c731371
2 changed files with 124 additions and 76 deletions

View File

@@ -106,35 +106,67 @@ class NotAMatch(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
schema: CoreSchema = model.__pydantic_core_schema__.copy()
# we shallow copied, be careful not to mutate the original schema!
class FieldValidator:
"""Utility class for validating individual fields of a Pydantic model without instantiating the whole model.
assert schema["type"] in ["definitions", "model"]
See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
"""
# find the field schema
field_schema = schema["schema"] # type: ignore
while "fields" not in field_schema:
field_schema = field_schema["schema"] # type: ignore
@staticmethod
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
"""Find the Pydantic core schema for a specific field in a model."""
schema: CoreSchema = model.__pydantic_core_schema__.copy()
# we shallow copied, be careful not to mutate the original schema!
field_schema = field_schema["fields"][field_name]["schema"] # type: ignore
assert schema["type"] in ["definitions", "model"]
# if the original schema is a definition schema, replace the model schema with the field schema
if schema["type"] == "definitions":
schema["schema"] = field_schema
return schema
else:
return field_schema
# find the field schema
field_schema = schema["schema"] # type: ignore
while "fields" not in field_schema:
field_schema = field_schema["schema"] # type: ignore
field_schema = field_schema["fields"][field_name]["schema"] # type: ignore
# if the original schema is a definition schema, replace the model schema with the field schema
if schema["type"] == "definitions":
schema["schema"] = field_schema
return schema
else:
return field_schema
@cache
@staticmethod
def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
"""Get a SchemaValidator for a specific field in a model."""
return SchemaValidator(FieldValidator.find_field_schema(model, field_name))
@staticmethod
def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
"""Validate a value for a specific field in a model."""
return FieldValidator.get_validator(model, field_name).validate_python(value)
@cache
def validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
return SchemaValidator(find_field_schema(model, field_name))
def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
"""Returns true if the state dict has all of the specified keys."""
_keys = {keys} if isinstance(keys, str) else keys
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
return validator(model, field_name).validate_python(value)
def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
def common_config_paths(path: Path) -> set[Path]:
"""Returns common config file paths for models stored in directories."""
return {path / "config.json", path / "model_index.json"}
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
@@ -225,7 +257,7 @@ def _validate_override_fields(
if field_name not in config_class.model_fields:
raise NotAMatch(config_class, f"unknown override field: {field_name}")
try:
validate_model_field(config_class, field_name, override_value)
FieldValidator.validate_field(config_class, field_name, override_value)
except ValidationError as e:
raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e
@@ -440,7 +472,13 @@ class T5EncoderConfig(ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
_validate_class_name(
cls,
common_config_paths(mod.path),
{
"T5EncoderModel",
},
)
cls._validate_has_unquantized_config_file(mod)
@@ -465,7 +503,13 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
_validate_class_name(
cls,
common_config_paths(mod.path),
{
"T5EncoderModel",
},
)
cls._validate_filename_looks_like_bnb_quantized(mod)
@@ -481,7 +525,7 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
@classmethod
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_scb_key_suffix = mod.has_keys_ending_with("SCB")
has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB")
if not has_scb_key_suffix:
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
@@ -592,23 +636,25 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = mod.has_keys_starting_with(
has_key_with_lora_prefix = has_keys_starting_with(
mod.load_state_dict(),
{
"lora_te_",
"lora_unet_",
"lora_te1_",
"lora_te2_",
"lora_transformer_",
}
},
)
has_key_with_lora_suffix = mod.has_keys_ending_with(
has_key_with_lora_suffix = has_keys_ending_with(
mod.load_state_dict(),
{
"to_k_lora.up.weight",
"to_q_lora.down.weight",
"lora_A.weight",
"lora_B.weight",
}
},
)
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
@@ -754,7 +800,13 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
if not has_keys_starting_with(
mod.load_state_dict(),
{
"encoder.conv_in",
"decoder.conv_in",
},
):
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
@classmethod
@@ -786,7 +838,14 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"})
_validate_class_name(
cls,
common_config_paths(mod.path),
{
"AutoencoderKL",
"AutoencoderTiny",
},
)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@@ -812,7 +871,7 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
config = _get_config_or_raise(cls, mod.common_config_paths())
config = _get_config_or_raise(cls, common_config_paths(mod.path))
if cls._config_looks_like_sdxl(config):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
@@ -843,7 +902,14 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"})
_validate_class_name(
cls,
common_config_paths(mod.path),
{
"ControlNetModel",
"FluxControlNetModel",
},
)
base = fields.get("base") or cls._get_base_or_raise(mod)
@@ -851,7 +917,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases:
config = _get_config_or_raise(cls, mod.common_config_paths())
config = _get_config_or_raise(cls, common_config_paths(mod.path))
if config.get("_class_name") == "FluxControlNetModel":
return BaseModelType.Flux
@@ -900,7 +966,8 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
@classmethod
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_starting_with(
if has_keys_starting_with(
mod.load_state_dict(),
{
"controlnet",
"control_model",
@@ -911,7 +978,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
}
},
):
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
@@ -1268,7 +1335,8 @@ class FLUX_Unquantized_CheckpointConfig(CheckpointConfigBase, MainConfigBase, Mo
@classmethod
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_exact(
if not has_keys_exact(
mod.load_state_dict(),
{
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -1426,7 +1494,7 @@ class SD_1_2_XL_XLRefiner_DiffusersConfig(DiffusersConfigBase, MainConfigBase, M
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
# SD 1.x and 2.x
"StableDiffusionPipeline",
@@ -1527,7 +1595,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase)
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"StableDiffusion3Pipeline",
"SD3Transformer2DModel",
@@ -1548,7 +1616,7 @@ class SD_3_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase)
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
config = _get_config_or_raise(cls, mod.common_config_paths())
config = _get_config_or_raise(cls, common_config_paths(mod.path))
submodels: dict[SubModelType, SubmodelDefinition] = {}
@@ -1601,8 +1669,10 @@ class CogView4_DiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigB
_validate_class_name(
cls,
mod.common_config_paths(),
{"CogView4Pipeline"},
common_config_paths(mod.path),
{
"CogView4Pipeline",
},
)
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
@@ -1706,13 +1776,14 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
@classmethod
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
if not mod.has_keys_starting_with(
if not has_keys_starting_with(
mod.load_state_dict(),
{
"image_proj.",
"ip_adapter.",
# XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
"ip_adapter_proj_model.",
}
},
):
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
@@ -1778,7 +1849,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"CLIPModel",
"CLIPTextModel",
@@ -1792,7 +1863,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
config = _get_config_or_raise(cls, mod.common_config_paths())
config = _get_config_or_raise(cls, common_config_paths(mod.path))
clip_variant = _get_clip_variant_type_from_config(config)
if clip_variant is not ClipVariantType.G:
@@ -1816,7 +1887,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"CLIPModel",
"CLIPTextModel",
@@ -1830,7 +1901,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
config = _get_config_or_raise(cls, mod.common_config_paths())
config = _get_config_or_raise(cls, common_config_paths(mod.path))
clip_variant = _get_clip_variant_type_from_config(config)
if clip_variant is not ClipVariantType.L:
@@ -1852,7 +1923,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"CLIPVisionModelWithProjection",
},
@@ -1882,7 +1953,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"T2IAdapter",
},
@@ -1894,7 +1965,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
config = _get_config_or_raise(cls, mod.common_config_paths())
config = _get_config_or_raise(cls, common_config_paths(mod.path))
adapter_type = config.get("adapter_type")
@@ -1955,7 +2026,7 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"SiglipModel",
},
@@ -1998,7 +2069,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
_validate_class_name(
cls,
mod.common_config_paths(),
common_config_paths(mod.path),
{
"LlavaOnevisionForConditionalGeneration",
},

View File

@@ -128,26 +128,3 @@ class ModelOnDisk:
f"Please specify the intended file using the 'path' argument"
)
return path
def has_keys_exact(self, keys: str | set[str], path: Optional[Path] = None) -> bool:
_keys = {keys} if isinstance(keys, str) else keys
state_dict = self.load_state_dict(path)
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
def has_keys_starting_with(self, prefixes: str | set[str], path: Optional[Path] = None) -> bool:
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
state_dict = self.load_state_dict(path)
return any(
any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)
)
def has_keys_ending_with(self, suffixes: str | set[str], path: Optional[Path] = None) -> bool:
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
state_dict = self.load_state_dict(path)
return any(
any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)
)
def common_config_paths(self) -> set[Path]:
"""Returns common config file paths for models stored in directories."""
return {self.path / "config.json", self.path / "model_index.json"}